From 3f78ce4f7b4faab40ba3646edaeb864a66e3d99b Mon Sep 17 00:00:00 2001
From: Claude
Date: Sun, 23 Nov 2025 12:25:51 +0000
Subject: [PATCH] Modernize Python type hints to use built-in generics and union syntax

- Replace List[] -> list[], Dict[] -> dict[], Tuple[] -> tuple[], Set[] -> set[]
- Replace Optional[X] -> X | None
- Use Union["Type", None] for forward reference cases to satisfy TC010 rule
- Clean up unused typing imports (List, Dict, Tuple, Set, Optional)
---
 .../elicitations/elicitation_forms_server.py | 5 +-
 .../elicitations/game_character_handler.py | 4 +-
 examples/new-api/simple_llm_advanced.py | 4 +-
 examples/openapi/openapi_mcp_server.py | 10 +-
 examples/tensorzero/image_demo.py | 4 +-
 hatch_build.py | 4 +-
 scripts/event_viewer.py | 9 +-
 scripts/gen_schema.py | 8 +-
 src/fast_agent/agents/llm_agent.py | 20 +-
 src/fast_agent/agents/llm_decorator.py | 106 ++++++++--------
 src/fast_agent/agents/mcp_agent.py | 47 ++++---
 src/fast_agent/agents/tool_agent.py | 10 +-
 src/fast_agent/agents/workflow/chain_agent.py | 22 ++--
 .../agents/workflow/evaluator_optimizer.py | 14 +--
 .../agents/workflow/iterative_planner.py | 22 ++--
 .../agents/workflow/orchestrator_models.py | 11 +-
 .../agents/workflow/parallel_agent.py | 38 +++---
 .../agents/workflow/router_agent.py | 20 +--
 src/fast_agent/cli/commands/auth.py | 24 ++--
 src/fast_agent/cli/commands/check_config.py | 9 +-
 src/fast_agent/cli/commands/server_helpers.py | 4 +-
 src/fast_agent/cli/commands/url_parser.py | 14 +--
 src/fast_agent/config.py | 20 +--
 src/fast_agent/context.py | 18 +--
 src/fast_agent/core/agent_app.py | 14 +--
 src/fast_agent/core/core_app.py | 8 +-
 src/fast_agent/core/direct_decorators.py | 87 +++++++------
 src/fast_agent/core/direct_factory.py | 22 ++--
 src/fast_agent/core/executor/executor.py | 22 ++--
 src/fast_agent/core/executor/task_registry.py | 12 +-
 .../core/executor/workflow_signal.py | 16 +--
 src/fast_agent/core/fastagent.py | 107 ++++++++--------
 src/fast_agent/core/logging/events.py | 14 +--
 .../core/logging/json_serializer.py | 6 +-
 src/fast_agent/core/logging/listeners.py | 8 +-
 src/fast_agent/core/logging/logger.py | 6 +-
 src/fast_agent/core/logging/transport.py | 10 +-
 src/fast_agent/core/validation.py | 16 +--
 src/fast_agent/event_progress.py | 11 +-
 src/fast_agent/history/history_exporter.py | 4 +-
 .../human_input/elicitation_handler.py | 4 +-
 .../human_input/elicitation_state.py | 5 +-
 src/fast_agent/human_input/form_fields.py | 118 +++++++++---------
 src/fast_agent/human_input/simple_form.py | 18 +--
 src/fast_agent/interfaces.py | 35 +++---
 src/fast_agent/llm/fastagent_llm.py | 71 +++++------
 src/fast_agent/llm/internal/passthrough.py | 12 +-
 src/fast_agent/llm/internal/playback.py | 10 +-
 src/fast_agent/llm/internal/slow.py | 4 +-
 src/fast_agent/llm/memory.py | 30 ++---
 src/fast_agent/llm/model_database.py | 18 +--
 src/fast_agent/llm/model_factory.py | 8 +-
 src/fast_agent/llm/model_info.py | 16 +--
 src/fast_agent/llm/prompt_utils.py | 20 +--
 .../llm/provider/anthropic/cache_planner.py | 11 +-
 .../llm/provider/anthropic/llm_anthropic.py | 48 +++----
 .../multipart_converter_anthropic.py | 8 +-
 .../llm/provider/bedrock/bedrock_utils.py | 70 +++++------
 .../llm/provider/bedrock/llm_bedrock.py | 82 ++++++------
 .../bedrock/multipart_converter_bedrock.py | 4 +-
 .../llm/provider/google/google_converter.py | 44 +++----
 .../llm/provider/google/llm_google_native.py | 37 +++---
 .../llm/provider/openai/llm_openai.py | 30 ++---
 .../provider/openai/llm_openai_compatible.py | 6 +-
.../provider/openai/llm_tensorzero_openai.py | 8 +- .../openai/multipart_converter_openai.py | 28 ++--- .../llm/provider/openai/openai_multipart.py | 8 +- .../llm/provider/openai/openai_utils.py | 4 +- .../llm/provider/openai/responses.py | 7 +- src/fast_agent/llm/provider_key_manager.py | 6 +- src/fast_agent/llm/request_params.py | 8 +- src/fast_agent/llm/sampling_converter.py | 7 +- src/fast_agent/llm/usage_tracking.py | 14 +-- src/fast_agent/mcp/elicitation_factory.py | 4 +- src/fast_agent/mcp/helpers/content_helpers.py | 20 +-- .../mcp/helpers/server_config_helpers.py | 4 +- src/fast_agent/mcp/hf_auth.py | 7 +- src/fast_agent/mcp/interfaces.py | 13 +- src/fast_agent/mcp/mcp_aggregator.py | 68 +++++----- src/fast_agent/mcp/mcp_connection_manager.py | 14 +-- src/fast_agent/mcp/mcp_content.py | 12 +- src/fast_agent/mcp/prompt.py | 8 +- src/fast_agent/mcp/prompt_message_extended.py | 18 +-- src/fast_agent/mcp/prompt_render.py | 3 +- src/fast_agent/mcp/prompt_serialization.py | 31 +++-- src/fast_agent/mcp/prompts/prompt_helpers.py | 14 +-- src/fast_agent/mcp/prompts/prompt_load.py | 8 +- src/fast_agent/mcp/prompts/prompt_server.py | 18 +-- src/fast_agent/mcp/prompts/prompt_template.py | 38 +++--- src/fast_agent/mcp/resource_utils.py | 7 +- src/fast_agent/mcp/server/agent_server.py | 4 +- src/fast_agent/mcp/skybridge.py | 7 +- src/fast_agent/mcp/ui_mixin.py | 26 ++-- src/fast_agent/mcp_server_registry.py | 9 +- src/fast_agent/skills/registry.py | 28 ++--- src/fast_agent/tools/elicitation.py | 36 +++--- src/fast_agent/tools/shell_runtime.py | 10 +- src/fast_agent/types/conversation_summary.py | 15 ++- src/fast_agent/types/message_search.py | 18 +-- src/fast_agent/ui/console_display.py | 30 ++--- src/fast_agent/ui/elicitation_form.py | 24 ++-- src/fast_agent/ui/enhanced_prompt.py | 18 +-- src/fast_agent/ui/interactive_prompt.py | 20 +-- src/fast_agent/ui/markdown_truncator.py | 30 ++--- src/fast_agent/ui/mcp_ui_utils.py | 12 +- src/fast_agent/ui/mermaid_utils.py | 7 +- src/fast_agent/ui/notification_tracker.py | 15 ++- src/fast_agent/ui/rich_progress.py | 4 +- src/fast_agent/ui/streaming_buffer.py | 22 ++-- src/fast_agent/ui/usage_display.py | 8 +- src/fast_agent/workflow_telemetry.py | 4 +- .../e2e/bedrock/test_dynamic_capabilities.py | 3 +- tests/e2e/bedrock/test_e2e_smoke_bedrock.py | 6 +- tests/e2e/smoke/base/test_e2e_smoke.py | 4 +- .../acp/test_acp_slash_commands.py | 10 +- .../elicitation_test_server_advanced.py | 5 +- .../elicitation/test_elicitation_handler.py | 4 +- .../test_elicitation_integration.py | 4 +- .../elicitation/testing_handlers.py | 4 +- .../test_prompt_server_integration.py | 4 +- .../test_load_prompt_templates.py | 4 +- tests/integration/tool_loop/test_tool_loop.py | 5 +- .../agents/test_mcp_agent_local_tools.py | 4 +- .../llm/providers/test_llm_azure.py | 11 +- .../llm/providers/test_llm_tensorzero_unit.py | 7 +- .../test_multipart_converter_google.py | 6 +- .../llm/test_cache_control_application.py | 4 +- .../llm/test_cache_walking_real_messages.py | 10 +- .../fast_agent/llm/test_prepare_arguments.py | 8 +- tests/unit/fast_agent/llm/test_structured.py | 4 +- tests/unit/fast_agent/mcp/test_ui_mixin.py | 5 +- .../fast_agent/tools/test_shell_runtime.py | 8 +- 132 files changed, 1153 insertions(+), 1210 deletions(-) diff --git a/examples/mcp/elicitations/elicitation_forms_server.py b/examples/mcp/elicitations/elicitation_forms_server.py index 1b58a893f..32b24f5fc 100644 --- a/examples/mcp/elicitations/elicitation_forms_server.py +++ 
b/examples/mcp/elicitations/elicitation_forms_server.py @@ -7,7 +7,6 @@ import logging import sys -from typing import Optional from mcp import ReadResourceResult from mcp.server.elicitation import ( @@ -38,13 +37,13 @@ async def event_registration() -> ReadResourceResult: class EventRegistration(BaseModel): name: str = Field(description="Your full name", min_length=2, max_length=100) email: str = Field(description="Your email address", json_schema_extra={"format": "email"}) - company_website: Optional[str] = Field( + company_website: str | None = Field( None, description="Your company website (optional)", json_schema_extra={"format": "uri"} ) event_date: str = Field( description="Which event date works for you?", json_schema_extra={"format": "date"} ) - dietary_requirements: Optional[str] = Field( + dietary_requirements: str | None = Field( None, description="Any dietary requirements? (optional)", max_length=200 ) diff --git a/examples/mcp/elicitations/game_character_handler.py b/examples/mcp/elicitations/game_character_handler.py index 03057c585..8606b95a5 100644 --- a/examples/mcp/elicitations/game_character_handler.py +++ b/examples/mcp/elicitations/game_character_handler.py @@ -8,7 +8,7 @@ import asyncio import random -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from mcp.shared.context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult @@ -35,7 +35,7 @@ async def game_character_elicitation_handler( if params.requestedSchema: properties = params.requestedSchema.get("properties", {}) - content: Dict[str, Any] = {} + content: dict[str, Any] = {} console.print("\n[bold magenta]🎮 Character Creation Studio 🎮[/bold magenta]\n") diff --git a/examples/new-api/simple_llm_advanced.py b/examples/new-api/simple_llm_advanced.py index 5c29f0c5c..b217644e5 100644 --- a/examples/new-api/simple_llm_advanced.py +++ b/examples/new-api/simple_llm_advanced.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict +from typing import Any from mcp.server.fastmcp.tools.base import Tool as FastMCPTool @@ -52,7 +52,7 @@ def calculate(operation: str, a: float, b: float) -> float: # Example 3: Complex async tool with side effects -async def send_email(to: str, subject: str, body: str) -> Dict[str, Any]: +async def send_email(to: str, subject: str, body: str) -> dict[str, Any]: """Send an email (mock implementation). Args: diff --git a/examples/openapi/openapi_mcp_server.py b/examples/openapi/openapi_mcp_server.py index a408d7de4..0074db64b 100644 --- a/examples/openapi/openapi_mcp_server.py +++ b/examples/openapi/openapi_mcp_server.py @@ -5,7 +5,7 @@ import argparse import logging from pathlib import Path -from typing import Any, Dict +from typing import Any import httpx import yaml @@ -18,11 +18,11 @@ class ApiCallRequest(BaseModel): method: str = Field(..., description="HTTP method to use, e.g. GET or POST.") path: str = Field(..., description="Endpoint path, such as /pets or pets.") - query: Dict[str, Any] | None = Field( + query: dict[str, Any] | None = Field( default=None, description="Optional query string parameters, keyed by name." ) body: Any | None = Field(default=None, description="Optional JSON request body.") - headers: Dict[str, str] | None = Field( + headers: dict[str, str] | None = Field( default=None, description="Optional HTTP headers to include with the request." 
) timeout: float | None = Field( @@ -68,7 +68,7 @@ def build_server(spec_text: str, base_url: str | None, server_name: str) -> Fast "and optional query parameters, JSON body, or headers." ), ) - async def call_openapi_endpoint(request: ApiCallRequest) -> Dict[str, Any]: + async def call_openapi_endpoint(request: ApiCallRequest) -> dict[str, Any]: if not base_url: raise RuntimeError("The OpenAPI specification does not define a server URL to call.") @@ -89,7 +89,7 @@ async def call_openapi_endpoint(request: ApiCallRequest) -> Dict[str, Any]: except ValueError: payload = None - result: Dict[str, Any] = { + result: dict[str, Any] = { "status_code": response.status_code, "headers": dict(response.headers), } diff --git a/examples/tensorzero/image_demo.py b/examples/tensorzero/image_demo.py index fe7748b7f..3442e1699 100644 --- a/examples/tensorzero/image_demo.py +++ b/examples/tensorzero/image_demo.py @@ -2,7 +2,7 @@ import base64 import mimetypes from pathlib import Path -from typing import List, Union +from typing import Union from mcp.types import ImageContent, TextContent @@ -37,7 +37,7 @@ request_params=RequestParams(template_vars=MY_T0_SYSTEM_VARS), ) async def main(): - content_parts: List[Union[TextContent, ImageContent]] = [] + content_parts: list[Union[TextContent, ImageContent]] = [] content_parts.append(TextContent(type="text", text=TEXT_PROMPT)) for file_path in LOCAL_IMAGE_FILES: diff --git a/hatch_build.py b/hatch_build.py index 25be6f4af..56cebfc7c 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -2,7 +2,7 @@ import shutil from pathlib import Path -from typing import Any, Dict +from typing import Any from hatchling.builders.hooks.plugin.interface import BuildHookInterface @@ -10,7 +10,7 @@ class CustomBuildHook(BuildHookInterface): """Custom build hook to copy examples to resources structure.""" - def initialize(self, version: str, build_data: Dict[str, Any]) -> None: + def initialize(self, version: str, build_data: dict[str, Any]) -> None: """Copy examples from root to resources structure.""" # Clear existing resources/examples directory for clean build resources_examples_dir = Path(self.root) / "src/fast_agent/resources/examples" diff --git a/scripts/event_viewer.py b/scripts/event_viewer.py index 58be51002..b6f78cf7d 100755 --- a/scripts/event_viewer.py +++ b/scripts/event_viewer.py @@ -7,7 +7,6 @@ import tty from datetime import datetime from pathlib import Path -from typing import List, Optional import typer from rich.console import Console @@ -33,13 +32,13 @@ def get_key() -> str: class EventDisplay: """Display MCP events from a log file.""" - def __init__(self, events: List[Event]) -> None: + def __init__(self, events: list[Event]) -> None: self.events = events self.total = len(events) self.current = 0 - self.current_iteration: Optional[int] = None + self.current_iteration: int | None = None self.tool_calls = 0 - self.progress_events: List[ProgressEvent] = [] + self.progress_events: list[ProgressEvent] = [] self._process_current() def next(self, steps: int = 1) -> None: @@ -154,7 +153,7 @@ def render(self) -> Panel: return Panel(main_layout, title="MCP Event Viewer") -def load_events(path: Path) -> List[Event]: +def load_events(path: Path) -> list[Event]: """Load events from JSONL file.""" events = [] print(f"Loading events from {path}") # Debug diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index a24a326b7..c8aadde6d 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -15,7 +15,7 @@ import re import sys from pathlib import Path -from typing 
import Any, Dict, Tuple +from typing import Any import typer from pydantic import BaseModel @@ -26,7 +26,7 @@ console = Console() -def extract_model_info(content: str) -> Dict[str, Dict[str, str]]: +def extract_model_info(content: str) -> dict[str, dict[str, str]]: """ Extract docstrings for all models and their fields. Returns a dict mapping model names to their field descriptions. @@ -132,7 +132,7 @@ def create_mock_modules() -> None: def load_settings_class( file_path: Path, -) -> Tuple[type[BaseSettings], Dict[str, Dict[str, str]]]: +) -> tuple[type[BaseSettings], dict[str, dict[str, str]]]: """Load Settings class from a Python file.""" # Add src directory to Python path src_dir = file_path.parent.parent.parent / "src" @@ -164,7 +164,7 @@ def load_settings_class( def apply_descriptions_to_schema( - schema: Dict[str, Any], model_info: Dict[str, Dict[str, str]] + schema: dict[str, Any], model_info: dict[str, dict[str, str]] ) -> None: """Recursively apply descriptions to schema and all its nested models.""" if not isinstance(schema, dict): diff --git a/src/fast_agent/agents/llm_agent.py b/src/fast_agent/agents/llm_agent.py index f7d8b7b31..3ae007f23 100644 --- a/src/fast_agent/agents/llm_agent.py +++ b/src/fast_agent/agents/llm_agent.py @@ -8,7 +8,7 @@ - Chat display integration """ -from typing import Callable, List, Optional, Tuple +from typing import Callable from a2a.types import AgentCapabilities from mcp import Tool @@ -79,12 +79,12 @@ def workflow_telemetry(self, provider: WorkflowTelemetryProvider | None) -> None async def show_assistant_message( self, message: PromptMessageExtended, - bottom_items: List[str] | None = None, - highlight_items: str | List[str] | None = None, + bottom_items: list[str] | None = None, + highlight_items: str | list[str] | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, - additional_message: Optional[Text] = None, + additional_message: Text | None = None, ) -> None: """Display an assistant message with appropriate styling based on stop reason. @@ -99,7 +99,7 @@ async def show_assistant_message( """ # Determine display content based on stop reason if not provided - additional_segments: List[Text] = [] + additional_segments: list[Text] = [] # Generate additional message based on stop reason match message.stop_reason: @@ -234,13 +234,13 @@ def _should_stream(self) -> bool: async def generate_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Enhanced generate implementation that resets tool call tracking. - Messages are already normalized to List[PromptMessageExtended]. + Messages are already normalized to list[PromptMessageExtended]. 
""" if "user" == messages[-1].role: self.show_user_message(message=messages[-1]) @@ -296,10 +296,10 @@ async def generate_impl( async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: if "user" == messages[-1].role: self.show_user_message(message=messages[-1]) diff --git a/src/fast_agent/agents/llm_decorator.py b/src/fast_agent/agents/llm_decorator.py index 6544eb4a1..305db5b06 100644 --- a/src/fast_agent/agents/llm_decorator.py +++ b/src/fast_agent/agents/llm_decorator.py @@ -9,12 +9,8 @@ TYPE_CHECKING, Any, Callable, - Dict, - List, Mapping, - Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -93,7 +89,7 @@ def remove_listener() -> None: return llm.add_stream_listener(listener) def add_tool_stream_listener( - self, listener: Callable[[str, Dict[str, Any] | None], None] + self, listener: Callable[[str, dict[str, Any] | None], None] ) -> Callable[[], None]: llm = getattr(self, "_llm", None) if not llm: @@ -125,8 +121,8 @@ class RemovedContentSummary: """Summary information about removed content for the last turn.""" model_name: str | None - counts: Dict[str, int] - category_mimes: Dict[str, Tuple[str, ...]] + counts: dict[str, int] + category_mimes: dict[str, tuple[str, ...]] alert_flags: frozenset[str] message: str @@ -135,10 +131,10 @@ class RemovedContentSummary: class _CallContext: """Internal helper for assembling an LLM call.""" - full_history: List[PromptMessageExtended] + full_history: list[PromptMessageExtended] call_params: RequestParams | None persist_history: bool - sanitized_messages: List[PromptMessageExtended] + sanitized_messages: list[PromptMessageExtended] summary: RemovedContentSummary | None @@ -166,13 +162,13 @@ def __init__( self.instruction = self.config.instruction # Agent-owned conversation state (PromptMessageExtended only) - self._message_history: List[PromptMessageExtended] = [] + self._message_history: list[PromptMessageExtended] = [] # Store the default request params from config self._default_request_params = self.config.default_request_params # Initialize the LLM to None (will be set by attach_llm) - self._llm: Optional[FastAgentLLMProtocol] = None + self._llm: FastAgentLLMProtocol | None = None self._initialized = False @property @@ -291,7 +287,7 @@ async def generate( Sequence[Union[str, PromptMessage, PromptMessageExtended]], ], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Create a completion with the LLM using the provided messages. @@ -322,9 +318,9 @@ async def generate( async def generate_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Implementation method for generate. 
@@ -371,7 +367,7 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_nam async def apply_prompt( self, prompt: Union[str, GetPromptResult], - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, as_template: bool = False, namespace: str | None = None, ) -> str: @@ -418,7 +414,7 @@ async def structured( ], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Apply the prompt and return the result as a Pydantic model. @@ -448,10 +444,10 @@ async def structured( async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Implementation method for structured. @@ -472,10 +468,10 @@ async def structured_impl( async def _generate_with_summary( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, - ) -> Tuple[PromptMessageExtended, RemovedContentSummary | None]: + tools: list[Tool] | None = None, + ) -> tuple[PromptMessageExtended, RemovedContentSummary | None]: assert self._llm, "LLM is not attached" call_ctx = self._prepare_llm_call(messages, request_params) @@ -490,10 +486,10 @@ async def _generate_with_summary( async def _structured_with_summary( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[Tuple[ModelT | None, PromptMessageExtended], RemovedContentSummary | None]: + ) -> tuple[tuple[ModelT | None, PromptMessageExtended], RemovedContentSummary | None]: assert self._llm, "LLM is not attached" call_ctx = self._prepare_llm_call(messages, request_params) @@ -510,7 +506,7 @@ async def _structured_with_summary( return structured_result, call_ctx.summary def _prepare_llm_call( - self, messages: List[PromptMessageExtended], request_params: RequestParams | None = None + self, messages: list[PromptMessageExtended], request_params: RequestParams | None = None ) -> _CallContext: """Normalize template/history handling for both generate and structured.""" sanitized_messages, summary = self._sanitize_messages_for_llm(messages) @@ -535,7 +531,7 @@ def _prepare_llm_call( def _persist_history( self, - sanitized_messages: List[PromptMessageExtended], + sanitized_messages: list[PromptMessageExtended], assistant_message: PromptMessageExtended, ) -> None: """Persist the last turn unless explicitly disabled by control text.""" @@ -559,14 +555,14 @@ def _strip_removed_metadata(message: PromptMessageExtended) -> PromptMessageExte return msg_copy def _sanitize_messages_for_llm( - self, messages: List[PromptMessageExtended] - ) -> Tuple[List[PromptMessageExtended], RemovedContentSummary | None]: + self, messages: list[PromptMessageExtended] + ) -> tuple[list[PromptMessageExtended], RemovedContentSummary | None]: """Filter out content blocks that the current model cannot tokenize.""" if not messages: return [], None - removed_blocks: List[_RemovedBlock] = [] - sanitized_messages: List[PromptMessageExtended] = [] + removed_blocks: list[_RemovedBlock] = [] + sanitized_messages: list[PromptMessageExtended] = [] for message in messages: sanitized, removed = 
self._sanitize_message_for_llm(message) @@ -589,17 +585,17 @@ def _sanitize_messages_for_llm( def _sanitize_message_for_llm( self, message: PromptMessageExtended - ) -> Tuple[PromptMessageExtended, List[_RemovedBlock]]: + ) -> tuple[PromptMessageExtended, list[_RemovedBlock]]: """Return a sanitized copy of a message and any removed content blocks.""" msg_copy = message.model_copy(deep=True) - removed: List[_RemovedBlock] = [] + removed: list[_RemovedBlock] = [] msg_copy.content = self._filter_block_list( list(msg_copy.content or []), removed, source="message" ) if msg_copy.tool_results: - new_tool_results: Dict[str, CallToolResult] = {} + new_tool_results: dict[str, CallToolResult] = {} for tool_id, tool_result in msg_copy.tool_results.items(): original_blocks = list(tool_result.content or []) filtered_blocks = self._filter_block_list( @@ -635,12 +631,12 @@ def _sanitize_message_for_llm( def _filter_block_list( self, blocks: Sequence[ContentBlock], - removed: List[_RemovedBlock], + removed: list[_RemovedBlock], *, source: str, tool_id: str | None = None, - ) -> List[ContentBlock]: - kept: List[ContentBlock] = [] + ) -> list[ContentBlock]: + kept: list[ContentBlock] = [] for block in blocks or []: mime_type, category = self._extract_block_metadata(block) if self._block_supported(mime_type, category): @@ -679,7 +675,7 @@ def _block_supported(self, mime_type: str | None, category: str) -> bool: return False - def _extract_block_metadata(self, block: ContentBlock) -> Tuple[str | None, str]: + def _extract_block_metadata(self, block: ContentBlock) -> tuple[str | None, str]: """Infer the MIME type and high-level category for a content block.""" if isinstance(block, TextContent): return "text/plain", "text" @@ -711,9 +707,9 @@ def _extract_block_metadata(self, block: ContentBlock) -> Tuple[str | None, str] return None, "document" - def _build_error_channel_entries(self, removed: List[_RemovedBlock]) -> List[ContentBlock]: + def _build_error_channel_entries(self, removed: list[_RemovedBlock]) -> list[ContentBlock]: """Create informative entries for the error channel.""" - entries: List[ContentBlock] = [] + entries: list[ContentBlock] = [] model_name = self._llm.model_name if self._llm else None model_display = model_name or "current model" @@ -736,8 +732,8 @@ def _build_error_channel_entries(self, removed: List[_RemovedBlock]) -> List[Con return entries - def _build_metadata_entries(self, removed: List[_RemovedBlock]) -> List[ContentBlock]: - entries: List[ContentBlock] = [] + def _build_metadata_entries(self, removed: list[_RemovedBlock]) -> list[ContentBlock]: + entries: list[ContentBlock] = [] for item in removed: metadata_text = text_content( json.dumps( @@ -753,13 +749,13 @@ def _build_metadata_entries(self, removed: List[_RemovedBlock]) -> List[ContentB entries.append(metadata_text) return entries - def _build_removed_summary(self, removed: List[_RemovedBlock]) -> RemovedContentSummary | None: + def _build_removed_summary(self, removed: list[_RemovedBlock]) -> RemovedContentSummary | None: if not removed: return None counts = Counter(item.category for item in removed) - category_mimes: Dict[str, Tuple[str, ...]] = {} - mime_accumulator: Dict[str, set[str]] = defaultdict(set) + category_mimes: dict[str, tuple[str, ...]] = {} + mime_accumulator: dict[str, set[str]] = defaultdict(set) for item in removed: mime_accumulator[item.category].add(item.mime_type or "unknown") @@ -778,7 +774,7 @@ def _build_removed_summary(self, removed: List[_RemovedBlock]) -> RemovedContent model_display = model_name 
or "current model" category_order = ["vision", "document", "other", "text"] - segments: List[str] = [] + segments: list[str] = [] for category in category_order: if category not in counts: continue @@ -849,7 +845,7 @@ def _category_label(category: str) -> str: return "content" @property - def message_history(self) -> List[PromptMessageExtended]: + def message_history(self) -> list[PromptMessageExtended]: """ Return the agent's message history as PromptMessageExtended objects. @@ -862,7 +858,7 @@ def message_history(self) -> List[PromptMessageExtended]: return self._message_history @property - def template_messages(self) -> List[PromptMessageExtended]: + def template_messages(self) -> list[PromptMessageExtended]: """ Return the template prefix of the message history. @@ -871,9 +867,9 @@ def template_messages(self) -> List[PromptMessageExtended]: """ return [msg.model_copy(deep=True) for msg in self._template_prefix_messages()] - def _template_prefix_messages(self) -> List[PromptMessageExtended]: + def _template_prefix_messages(self) -> list[PromptMessageExtended]: """Return the leading messages marked as templates (non-copy).""" - prefix: List[PromptMessageExtended] = [] + prefix: list[PromptMessageExtended] = [] for msg in self._message_history: if msg.is_template: prefix.append(msg) @@ -906,20 +902,20 @@ def llm(self) -> FastAgentLLMProtocol: # --- Default MCP-facing convenience methods (no-op for plain LLM agents) --- - async def list_prompts(self, namespace: str | None = None) -> Mapping[str, List[Prompt]]: + async def list_prompts(self, namespace: str | None = None) -> Mapping[str, list[Prompt]]: """Default: no prompts; return empty mapping.""" return {} async def get_prompt( self, prompt_name: str, - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, namespace: str | None = None, ) -> GetPromptResult: """Default: prompts unsupported; return empty GetPromptResult.""" return GetPromptResult(description="", messages=[]) - async def list_resources(self, namespace: str | None = None) -> Mapping[str, List[str]]: + async def list_resources(self, namespace: str | None = None) -> Mapping[str, list[str]]: """Default: no resources; return empty mapping.""" return {} @@ -927,7 +923,7 @@ async def list_tools(self) -> ListToolsResult: """Default: no tools; return empty ListToolsResult.""" return ListToolsResult(tools=[]) - async def list_mcp_tools(self, namespace: str | None = None) -> Mapping[str, List[Tool]]: + async def list_mcp_tools(self, namespace: str | None = None) -> Mapping[str, list[Tool]]: """Default: no tools; return empty mapping.""" return {} @@ -1012,11 +1008,11 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend async def show_assistant_message( self, message: PromptMessageExtended, - bottom_items: List[str] | None = None, - highlight_items: str | List[str] | None = None, + bottom_items: list[str] | None = None, + highlight_items: str | list[str] | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, - additional_message: Optional["Text"] = None, + additional_message: Union["Text", None] = None, ) -> None: pass diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index 6d687bf96..2be49412f 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -12,11 +12,8 @@ TYPE_CHECKING, Any, Callable, - Dict, Iterable, - List, Mapping, - Optional, Sequence, TypeVar, Union, @@ -108,7 +105,7 @@ def __init__( 
self.instruction = self.config.instruction self.executor = context.executor if context else None self.logger = get_logger(f"{__name__}.{self._name}") - manifests: List[SkillManifest] = list(getattr(self.config, "skill_manifests", []) or []) + manifests: list[SkillManifest] = list(getattr(self.config, "skill_manifests", []) or []) if not manifests and context and getattr(context, "skill_registry", None): try: manifests = list(context.skill_registry.load_manifests()) # type: ignore[assignment] @@ -116,7 +113,7 @@ def __init__( manifests = [] self._skill_manifests = list(manifests) - self._skill_map: Dict[str, SkillManifest] = { + self._skill_map: dict[str, SkillManifest] = { manifest.name: manifest for manifest in manifests } self._agent_skills_warning_shown = False @@ -244,7 +241,7 @@ async def shutdown(self) -> None: """ await self._aggregator.close() - async def get_server_status(self) -> Dict[str, ServerStatus]: + async def get_server_status(self) -> dict[str, ServerStatus]: """Expose server status details for UI and diagnostics consumers.""" if not self._aggregator: return {} @@ -313,7 +310,7 @@ async def _apply_instruction_templates(self) -> None: self.logger.debug(f"Applied instruction templates for agent {self._name}") def _format_server_instructions( - self, instructions_data: Dict[str, tuple[str | None, List[str]]] + self, instructions_data: dict[str, tuple[str | None, list[str]]] ) -> str: """ Format server instructions with XML tags and tool lists. @@ -494,7 +491,7 @@ def set_filesystem_runtime(self, runtime) -> None: ) async def call_tool( - self, name: str, arguments: Dict[str, Any] | None = None, tool_use_id: str | None = None + self, name: str, arguments: dict[str, Any] | None = None, tool_use_id: str | None = None ) -> CallToolResult: """ Call a tool by name with the given arguments. @@ -536,7 +533,7 @@ async def call_tool( return await self._aggregator.call_tool(name, arguments, tool_use_id) async def _call_human_input_tool( - self, arguments: Dict[str, Any] | None = None + self, arguments: dict[str, Any] | None = None ) -> CallToolResult: """ Handle human input via an elicitation form. @@ -593,7 +590,7 @@ async def _call_human_input_tool( async def get_prompt( self, prompt_name: str, - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, namespace: str | None = None, server_name: str | None = None, ) -> GetPromptResult: @@ -614,7 +611,7 @@ async def get_prompt( async def apply_prompt( self, prompt: Union[str, GetPromptResult], - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, as_template: bool = False, namespace: str | None = None, **_: Any, @@ -679,7 +676,7 @@ async def apply_prompt( async def get_embedded_resources( self, resource_uri: str, server_name: str | None = None - ) -> List[EmbeddedResource]: + ) -> list[EmbeddedResource]: """ Get a resource from an MCP server and return it as a list of embedded resources ready for use in prompts. 
@@ -697,7 +694,7 @@ async def get_embedded_resources( result: ReadResourceResult = await self._aggregator.get_resource(resource_uri, server_name) # Convert each resource content to an EmbeddedResource - embedded_resources: List[EmbeddedResource] = [] + embedded_resources: list[EmbeddedResource] = [] for resource_content in result.contents: embedded_resource = EmbeddedResource( type="resource", resource=resource_content, annotations=None @@ -749,7 +746,7 @@ async def with_resource( The agent's response as a string """ # Get the embedded resources - embedded_resources: List[EmbeddedResource] = await self.get_embedded_resources( + embedded_resources: list[EmbeddedResource] = await self.get_embedded_resources( resource_uri, namespace if namespace is not None else server_name ) @@ -1022,7 +1019,7 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_nam return await self._llm.apply_prompt_template(prompt_result, prompt_name) async def apply_prompt_messages( - self, prompts: List[PromptMessageExtended], request_params: RequestParams | None = None + self, prompts: list[PromptMessageExtended], request_params: RequestParams | None = None ) -> str: """ Apply a list of prompt messages and return the result. @@ -1040,7 +1037,7 @@ async def apply_prompt_messages( async def list_prompts( self, namespace: str | None = None, server_name: str | None = None - ) -> Mapping[str, List[mcp.types.Prompt]]: + ) -> Mapping[str, list[mcp.types.Prompt]]: """ List all prompts available to this agent, filtered by configuration. @@ -1062,7 +1059,7 @@ async def list_prompts( async def list_resources( self, namespace: str | None = None, server_name: str | None = None - ) -> Dict[str, List[str]]: + ) -> dict[str, list[str]]: """ List all resources available to this agent, filtered by configuration. @@ -1082,7 +1079,7 @@ async def list_resources( lambda resource: resource, ) - async def list_mcp_tools(self, namespace: str | None = None) -> Mapping[str, List[Tool]]: + async def list_mcp_tools(self, namespace: str | None = None) -> Mapping[str, list[Tool]]: """ List all tools available to this agent, grouped by server and filtered by configuration. @@ -1160,7 +1157,7 @@ async def agent_card(self) -> AgentCard: Return an A2A card describing this Agent """ - skills: List[AgentSkill] = [] + skills: list[AgentSkill] = [] tools: ListToolsResult = await self.list_tools() for tool in tools.tools: skills.append(await self.convert(tool)) @@ -1181,12 +1178,12 @@ async def agent_card(self) -> AgentCard: async def show_assistant_message( self, message: PromptMessageExtended, - bottom_items: List[str] | None = None, - highlight_items: str | List[str] | None = None, + bottom_items: list[str] | None = None, + highlight_items: str | list[str] | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, - additional_message: Optional["Text"] = None, + additional_message: Union["Text", None] = None, ) -> None: """ Display an assistant message with MCP servers in the bottom bar. @@ -1230,7 +1227,7 @@ async def show_assistant_message( additional_message=additional_message, ) - def _extract_servers_from_message(self, message: PromptMessageExtended) -> List[str]: + def _extract_servers_from_message(self, message: PromptMessageExtended) -> list[str]: """ Extract server names from tool calls in the message. 
@@ -1310,7 +1307,7 @@ async def convert(self, tool: Tool) -> AgentSkill: ) @property - def message_history(self) -> List[PromptMessageExtended]: + def message_history(self) -> list[PromptMessageExtended]: """ Return the agent's message history as PromptMessageExtended objects. @@ -1324,7 +1321,7 @@ def message_history(self) -> List[PromptMessageExtended]: return super().message_history @property - def usage_accumulator(self) -> Optional["UsageAccumulator"]: + def usage_accumulator(self) -> Union["UsageAccumulator", None]: """ Return the usage accumulator for tracking token usage across turns. diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index 9f76a7045..a2952eb84 100644 --- a/src/fast_agent/agents/tool_agent.py +++ b/src/fast_agent/agents/tool_agent.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Sequence +from typing import Any, Callable, Sequence from mcp.server.fastmcp.tools.base import Tool as FastMCPTool from mcp.types import CallToolResult, ListToolsResult, Tool @@ -76,13 +76,13 @@ def __init__( async def generate_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Generate a response using the LLM, and handle tool calls if necessary. - Messages are already normalized to List[PromptMessageExtended]. + Messages are already normalized to list[PromptMessageExtended]. """ if tools is None: tools = (await self.list_tools()).tools @@ -241,7 +241,7 @@ async def list_tools(self) -> ListToolsResult: """Return available tools for this agent. Overridable by subclasses.""" return ListToolsResult(tools=list(self._tool_schemas)) - async def call_tool(self, name: str, arguments: Dict[str, Any] | None = None) -> CallToolResult: + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> CallToolResult: """Execute a tool by name using local FastMCP tools. Overridable by subclasses.""" fast_tool = self._execution_tools.get(name) if not fast_tool: diff --git a/src/fast_agent/agents/workflow/chain_agent.py b/src/fast_agent/agents/workflow/chain_agent.py index b15d8ec7f..6bba92678 100644 --- a/src/fast_agent/agents/workflow/chain_agent.py +++ b/src/fast_agent/agents/workflow/chain_agent.py @@ -5,7 +5,7 @@ other agents, chaining their outputs together. """ -from typing import Any, List, Optional, Tuple, Type +from typing import Any, Type from mcp import Tool from mcp.types import TextContent @@ -35,9 +35,9 @@ def agent_type(self) -> AgentType: def __init__( self, config: AgentConfig, - agents: List[LlmAgent], + agents: list[LlmAgent], cumulative: bool = False, - context: Optional[Any] = None, + context: Any | None = None, **kwargs, ) -> None: """ @@ -56,9 +56,9 @@ def __init__( async def generate_impl( self, - messages: List[PromptMessageExtended], - request_params: Optional[RequestParams] = None, - tools: List[Tool] | None = None, + messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Chain the request through multiple agents in sequence. 
@@ -85,10 +85,10 @@ async def generate_impl( return response # Track all responses in the chain - all_responses: List[PromptMessageExtended] = [] + all_responses: list[PromptMessageExtended] = [] # Initialize list for storing formatted results - final_results: List[str] = [] + final_results: list[str] = [] # Add the original request with XML tag request_text = ( @@ -128,10 +128,10 @@ async def generate_impl( async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], - request_params: Optional[RequestParams] = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + request_params: RequestParams | None = None, + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Chain the request through multiple agents and parse the final response. diff --git a/src/fast_agent/agents/workflow/evaluator_optimizer.py b/src/fast_agent/agents/workflow/evaluator_optimizer.py index 0cc242b22..5065aafe5 100644 --- a/src/fast_agent/agents/workflow/evaluator_optimizer.py +++ b/src/fast_agent/agents/workflow/evaluator_optimizer.py @@ -8,7 +8,7 @@ """ from enum import Enum -from typing import Any, List, Optional, Tuple, Type +from typing import Any, Type from mcp import Tool from pydantic import BaseModel, Field @@ -48,7 +48,7 @@ class EvaluationResult(BaseModel): rating: QualityRating = Field(description="Quality rating of the response") feedback: str = Field(description="Specific feedback and suggestions for improvement") needs_improvement: bool = Field(description="Whether the output needs further improvement") - focus_areas: List[str] = Field( + focus_areas: list[str] = Field( default_factory=list, description="Specific areas to focus on in next iteration" ) @@ -74,7 +74,7 @@ def __init__( evaluator_agent: AgentProtocol, min_rating: QualityRating = QualityRating.GOOD, max_refinements: int = 3, - context: Optional[Any] = None, + context: Any | None = None, **kwargs, ) -> None: """ @@ -105,9 +105,9 @@ def __init__( async def generate_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Generate a response through evaluation-guided refinement. @@ -204,10 +204,10 @@ async def generate_impl( async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Generate an optimized response and parse it into a structured format. 
diff --git a/src/fast_agent/agents/workflow/iterative_planner.py b/src/fast_agent/agents/workflow/iterative_planner.py index 62a51cc3e..3caa109ab 100644 --- a/src/fast_agent/agents/workflow/iterative_planner.py +++ b/src/fast_agent/agents/workflow/iterative_planner.py @@ -3,7 +3,7 @@ """ import asyncio -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Type from mcp import Tool from mcp.types import TextContent @@ -166,9 +166,9 @@ def agent_type(self) -> AgentType: def __init__( self, config: AgentConfig, - agents: List[AgentProtocol], + agents: list[AgentProtocol], plan_iterations: int = -1, - context: Optional[Any] = None, + context: Any | None = None, **kwargs, ) -> None: """ @@ -185,7 +185,7 @@ def __init__( raise AgentConfigError("At least one worker agent must be provided") # Store agents by name for easier lookup - self.agents: Dict[str, AgentProtocol] = {} + self.agents: dict[str, AgentProtocol] = {} for agent in agents: agent_name = agent.name self.agents[agent_name] = agent @@ -237,9 +237,9 @@ async def shutdown(self) -> None: async def generate_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Execute an orchestrated plan to process the input. @@ -262,10 +262,10 @@ async def generate_impl( async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], - request_params: Optional[RequestParams] = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + request_params: RequestParams | None = None, + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Execute an orchestration plan and parse the result into a structured format. @@ -386,7 +386,7 @@ async def _execute_step(self, step: Step, previous_result: PlanResult) -> Any: for task in step.tasks: tasks_by_agent[task.agent].append(task) - async def execute_agent_tasks(agent_name: str, agent_tasks: List) -> List[TaskWithResult]: + async def execute_agent_tasks(agent_name: str, agent_tasks: list) -> list[TaskWithResult]: """Execute all tasks for a single agent sequentially (preserves history)""" agent = self.agents.get(agent_name) assert agent is not None @@ -506,7 +506,7 @@ async def _get_next_step( logger.error(f"Failed to parse next step: {str(e)}") return None - def _validate_agent_names(self, plan: Plan) -> List[str]: + def _validate_agent_names(self, plan: Plan) -> list[str]: """ Validate all agent names in a plan before execution. 
diff --git a/src/fast_agent/agents/workflow/orchestrator_models.py b/src/fast_agent/agents/workflow/orchestrator_models.py index 6f4e7b353..432e9f482 100644 --- a/src/fast_agent/agents/workflow/orchestrator_models.py +++ b/src/fast_agent/agents/workflow/orchestrator_models.py @@ -1,4 +1,3 @@ -from typing import List from pydantic import BaseModel, ConfigDict, Field @@ -18,7 +17,7 @@ class Task(BaseModel): class ServerTask(Task): """An individual task that can be accomplished by one or more MCP servers""" - servers: List[str] = Field( + servers: list[str] = Field( description="Names of MCP servers that the LLM has access to for this task", default_factory=list, ) @@ -37,7 +36,7 @@ class Step(BaseModel): description: str = Field(description="Description of the step") - tasks: List[AgentTask] = Field( + tasks: list[AgentTask] = Field( description="Subtasks that can be executed in parallel", default_factory=list, ) @@ -52,7 +51,7 @@ class PlanningStep(Step): class Plan(BaseModel): """Plan generated by the orchestrator planner.""" - steps: List[Step] = Field( + steps: list[Step] = Field( description="List of steps to execute sequentially", default_factory=list, ) @@ -73,7 +72,7 @@ class StepResult(BaseModel): """Result of executing a step""" step: Step = Field(description="The step that was executed") - task_results: List[TaskWithResult] = Field( + task_results: list[TaskWithResult] = Field( description="Results of executing each task", default_factory=list ) result: str = Field(description="Result of executing the step", default="Step completed") @@ -94,7 +93,7 @@ class PlanResult(BaseModel): plan: Plan | None = None """The plan that was executed""" - step_results: List[StepResult] + step_results: list[StepResult] """Results of executing each step""" is_complete: bool = False diff --git a/src/fast_agent/agents/workflow/parallel_agent.py b/src/fast_agent/agents/workflow/parallel_agent.py index d4e4af519..310e67dc1 100644 --- a/src/fast_agent/agents/workflow/parallel_agent.py +++ b/src/fast_agent/agents/workflow/parallel_agent.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, List, Optional, Tuple +from typing import Any from mcp import Tool from mcp.types import TextContent @@ -31,7 +31,7 @@ def __init__( self, config: AgentConfig, fan_in_agent: AgentProtocol, - fan_out_agents: List[AgentProtocol], + fan_out_agents: list[AgentProtocol], include_request: bool = True, **kwargs, ) -> None: @@ -52,9 +52,9 @@ def __init__( async def generate_impl( self, - messages: List[PromptMessageExtended], - request_params: Optional[RequestParams] = None, - tools: List[Tool] | None = None, + messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Execute fan-out agents in parallel and aggregate their results with the fan-in agent. 
@@ -69,12 +69,12 @@ async def generate_impl( tracer = trace.get_tracer(__name__) with tracer.start_as_current_span(f"Parallel: '{self._name}' generate"): - responses: List[PromptMessageExtended] = await self._execute_fan_out( + responses: list[PromptMessageExtended] = await self._execute_fan_out( messages, request_params ) # Extract the received message from the input - received_message: Optional[str] = messages[-1].all_text() if messages else None + received_message: str | None = messages[-1].all_text() if messages else None # Convert responses to strings for aggregation string_responses = [] @@ -92,7 +92,7 @@ async def generate_impl( # Use the fan-in agent to aggregate the responses return await self._fan_in_generate(formatted_prompt, request_params) - def _format_responses(self, responses: List[Any], message: Optional[str] = None) -> str: + def _format_responses(self, responses: list[Any], message: str | None = None) -> str: """ Format a list of responses for the fan-in agent. @@ -120,10 +120,10 @@ def _format_responses(self, responses: List[Any], message: Optional[str] = None) async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: type[ModelT], - request_params: Optional[RequestParams] = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + request_params: RequestParams | None = None, + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Apply the prompt and return the result as a Pydantic model. @@ -140,12 +140,12 @@ async def structured_impl( tracer = trace.get_tracer(__name__) with tracer.start_as_current_span(f"Parallel: '{self._name}' generate"): - responses: List[PromptMessageExtended] = await self._execute_fan_out( + responses: list[PromptMessageExtended] = await self._execute_fan_out( messages, request_params ) # Extract the received message - received_message: Optional[str] = messages[-1].all_text() if messages else None + received_message: str | None = messages[-1].all_text() if messages else None # Convert responses to strings string_responses = [response.all_text() for response in responses] @@ -195,9 +195,9 @@ async def shutdown(self) -> None: async def _execute_fan_out( self, - messages: List[PromptMessageExtended], - request_params: Optional[RequestParams], - ) -> List[PromptMessageExtended]: + messages: list[PromptMessageExtended], + request_params: RequestParams | None, + ) -> list[PromptMessageExtended]: """ Run fan-out agents with telemetry so transports can surface progress. """ @@ -217,7 +217,7 @@ async def _run_agent(agent: AgentProtocol) -> PromptMessageExtended: async def _fan_in_generate( self, prompt: PromptMessageExtended, - request_params: Optional[RequestParams], + request_params: RequestParams | None, ) -> PromptMessageExtended: """ Aggregate fan-out output with telemetry. @@ -235,8 +235,8 @@ async def _fan_in_structured( self, prompt: PromptMessageExtended, model: type[ModelT], - request_params: Optional[RequestParams], - ) -> Tuple[ModelT | None, PromptMessageExtended]: + request_params: RequestParams | None, + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Structured aggregation with telemetry. """ diff --git a/src/fast_agent/agents/workflow/router_agent.py b/src/fast_agent/agents/workflow/router_agent.py index 7d39dc869..6cdfe6057 100644 --- a/src/fast_agent/agents/workflow/router_agent.py +++ b/src/fast_agent/agents/workflow/router_agent.py @@ -5,7 +5,7 @@ by determining the best agent for a request and dispatching to it. 
""" -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Type from mcp import Tool from opentelemetry import trace @@ -73,7 +73,7 @@ def agent_type(self) -> AgentType: def __init__( self, config: AgentConfig, - agents: List[LlmAgent], + agents: list[LlmAgent], routing_instruction: str | None = None, context: "Context | None" = None, default_request_params: RequestParams | None = None, @@ -143,7 +143,7 @@ async def shutdown(self) -> None: @staticmethod async def _generate_routing_instruction( - agents: List[LlmAgent], routing_instruction: Optional[str] = None + agents: list[LlmAgent], routing_instruction: str | None = None ) -> str: """ Generate the complete routing instruction with agent cards. @@ -184,9 +184,9 @@ async def attach_llm( async def generate_impl( self, - messages: List[PromptMessageExtended], - request_params: Optional[RequestParams] = None, - tools: List[Tool] | None = None, + messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Route the request to the most appropriate agent and return its response. @@ -235,10 +235,10 @@ async def generate_impl( async def structured_impl( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], - request_params: Optional[RequestParams] = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + request_params: RequestParams | None = None, + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Route the request to the most appropriate agent and parse its response. @@ -285,7 +285,7 @@ async def structured_impl( async def _route_request( self, message: PromptMessageExtended - ) -> Tuple[RoutingResponse | None, str | None]: + ) -> tuple[RoutingResponse | None, str | None]: """ Determine which agent to route the request to. 
diff --git a/src/fast_agent/cli/commands/auth.py b/src/fast_agent/cli/commands/auth.py index 19d00510f..00c5b4bf2 100644 --- a/src/fast_agent/cli/commands/auth.py +++ b/src/fast_agent/cli/commands/auth.py @@ -5,8 +5,6 @@ from __future__ import annotations -from typing import Dict, List, Optional - import typer from rich.table import Table @@ -99,9 +97,9 @@ def _server_rows_from_settings(settings: Settings): return rows -def _servers_by_identity(settings: Settings) -> Dict[str, List[str]]: +def _servers_by_identity(settings: Settings) -> dict[str, list[str]]: """Group configured server names by derived identity (base URL).""" - mapping: Dict[str, List[str]] = {} + mapping: dict[str, list[str]] = {} mcp = getattr(settings, "mcp", None) servers = getattr(mcp, "servers", {}) if mcp else {} for name, cfg in servers.items(): @@ -115,8 +113,8 @@ def _servers_by_identity(settings: Settings) -> Dict[str, List[str]]: @app.command() def status( - target: Optional[str] = typer.Argument(None, help="Identity (base URL) or server name"), - config_path: Optional[str] = typer.Option(None, "--config-path", "-c"), + target: str | None = typer.Argument(None, help="Identity (base URL) or server name"), + config_path: str | None = typer.Option(None, "--config-path", "-c"), ) -> None: """Show keyring backend and token status for configured MCP servers.""" settings = get_settings(config_path) @@ -240,12 +238,12 @@ def status( @app.command() def clear( - server: Optional[str] = typer.Argument(None, help="Server name to clear (from config)"), - identity: Optional[str] = typer.Option( + server: str | None = typer.Argument(None, help="Server name to clear (from config)"), + identity: str | None = typer.Option( None, "--identity", help="Token identity (base URL) to clear" ), all: bool = typer.Option(False, "--all", help="Clear tokens for all identities in keyring"), - config_path: Optional[str] = typer.Option(None, "--config-path", "-c"), + config_path: str | None = typer.Option(None, "--config-path", "-c"), ) -> None: """Clear stored OAuth tokens from the keyring.""" targets_identities: list[str] = [] @@ -281,7 +279,7 @@ def clear( @app.callback(invoke_without_command=True) def main( - ctx: typer.Context, config_path: Optional[str] = typer.Option(None, "--config-path", "-c") + ctx: typer.Context, config_path: str | None = typer.Option(None, "--config-path", "-c") ) -> None: """Default to showing status if no subcommand is provided.""" if ctx.invoked_subcommand is None: @@ -293,13 +291,13 @@ def main( @app.command() def login( - target: Optional[str] = typer.Argument( + target: str | None = typer.Argument( None, help="Server name (from config) or identity (base URL)" ), - transport: Optional[str] = typer.Option( + transport: str | None = typer.Option( None, "--transport", help="Transport for identity mode: http or sse" ), - config_path: Optional[str] = typer.Option(None, "--config-path", "-c"), + config_path: str | None = typer.Option(None, "--config-path", "-c"), ) -> None: """Start OAuth flow and store tokens for a server. 
diff --git a/src/fast_agent/cli/commands/check_config.py b/src/fast_agent/cli/commands/check_config.py index 24218989b..955f8ef85 100644 --- a/src/fast_agent/cli/commands/check_config.py +++ b/src/fast_agent/cli/commands/check_config.py @@ -4,7 +4,6 @@ import sys from importlib.metadata import version from pathlib import Path -from typing import Optional import typer import yaml @@ -22,7 +21,7 @@ ) -def find_config_files(start_path: Path) -> dict[str, Optional[Path]]: +def find_config_files(start_path: Path) -> dict[str, Path | None]: """Find FastAgent configuration files, preferring secrets file next to config file.""" from fast_agent.config import find_fastagent_config_files @@ -43,7 +42,7 @@ def get_system_info() -> dict: } -def get_secrets_summary(secrets_path: Optional[Path]) -> dict: +def get_secrets_summary(secrets_path: Path | None) -> dict: """Extract information from the secrets file.""" result = { "status": "not_found", # Default status: not found @@ -143,7 +142,7 @@ def get_fastagent_version() -> str: return "unknown" -def get_config_summary(config_path: Optional[Path]) -> dict: +def get_config_summary(config_path: Path | None) -> dict: """Extract key information from the configuration file.""" from fast_agent.config import MCPTimelineSettings, Settings @@ -727,7 +726,7 @@ def _truncate(text: str, length: int = 70) -> str: @app.command() def show( - path: Optional[str] = typer.Argument(None, help="Path to configuration file to display"), + path: str | None = typer.Argument(None, help="Path to configuration file to display"), secrets: bool = typer.Option( False, "--secrets", "-s", help="Show secrets file instead of config" ), diff --git a/src/fast_agent/cli/commands/server_helpers.py b/src/fast_agent/cli/commands/server_helpers.py index 3f0af9f9e..14e6f09ce 100644 --- a/src/fast_agent/cli/commands/server_helpers.py +++ b/src/fast_agent/cli/commands/server_helpers.py @@ -1,6 +1,6 @@ """Helper functions for server configuration and naming.""" -from typing import Any, Dict +from typing import Any def generate_server_name(identifier: str) -> str: @@ -59,7 +59,7 @@ def generate_server_name(identifier: str) -> str: return server_name -async def add_servers_to_config(fast_app: Any, servers: Dict[str, Dict[str, Any]]) -> None: +async def add_servers_to_config(fast_app: Any, servers: dict[str, dict[str, Any]]) -> None: """Add server configurations to the FastAgent app config. This function handles the repetitive initialization and configuration diff --git a/src/fast_agent/cli/commands/url_parser.py b/src/fast_agent/cli/commands/url_parser.py index b6e3e3f4a..337da45b9 100644 --- a/src/fast_agent/cli/commands/url_parser.py +++ b/src/fast_agent/cli/commands/url_parser.py @@ -5,7 +5,7 @@ import hashlib import re -from typing import Dict, List, Literal, Tuple +from typing import Literal from urllib.parse import urlparse from fast_agent.mcp.hf_auth import add_hf_auth_header @@ -13,7 +13,7 @@ def parse_server_url( url: str, -) -> Tuple[str, Literal["http", "sse"], str]: +) -> tuple[str, Literal["http", "sse"], str]: """ Parse a server URL and determine the transport type and server name. @@ -103,7 +103,7 @@ def generate_server_name(url: str) -> str: def parse_server_urls( urls_param: str, auth_token: str | None = None -) -> List[Tuple[str, Literal["http", "sse"], str, Dict[str, str] | None]]: +) -> list[tuple[str, Literal["http", "sse"], str, dict[str, str] | None]]: """ Parse a comma-separated list of URLs into server configurations. 
@@ -142,8 +142,8 @@ def parse_server_urls( def generate_server_configs( - parsed_urls: List[Tuple[str, Literal["http", "sse"], str, Dict[str, str] | None]], -) -> Dict[str, Dict[str, str | Dict[str, str]]]: + parsed_urls: list[tuple[str, Literal["http", "sse"], str, dict[str, str] | None]], +) -> dict[str, dict[str, str | dict[str, str]]]: """ Generate server configurations from parsed URLs. @@ -153,7 +153,7 @@ def generate_server_configs( Returns: Dictionary of server configurations """ - server_configs: Dict[str, Dict[str, str | Dict[str, str]]] = {} + server_configs: dict[str, dict[str, str | dict[str, str]]] = {} # Keep track of server name occurrences to handle collisions name_counts = {} @@ -176,7 +176,7 @@ def generate_server_configs( final_name = f"{server_name}_{suffix}" name_counts[server_name] += 1 - config: Dict[str, str | Dict[str, str]] = { + config: dict[str, str | dict[str, str]] = { "transport": transport_type, "url": url, } diff --git a/src/fast_agent/config.py b/src/fast_agent/config.py index 8d85261ae..f5651b1a7 100644 --- a/src/fast_agent/config.py +++ b/src/fast_agent/config.py @@ -6,7 +6,7 @@ import os import re from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Literal from mcp import Implementation from pydantic import BaseModel, ConfigDict, field_validator, model_validator @@ -138,10 +138,10 @@ class MCPRootSettings(BaseModel): uri: str """The URI identifying the root. Must start with file://""" - name: Optional[str] = None + name: str | None = None """Optional name for the root.""" - server_uri_alias: Optional[str] = None + server_uri_alias: str | None = None """Optional URI alias for presentation to the server""" @field_validator("uri", "server_uri_alias") @@ -172,7 +172,7 @@ class MCPServerSettings(BaseModel): command: str | None = None """The command to execute the server (e.g. npx).""" - args: List[str] | None = None + args: list[str] | None = None """The arguments for the server command.""" read_timeout_seconds: int | None = None @@ -184,16 +184,16 @@ class MCPServerSettings(BaseModel): url: str | None = None """The URL for the server (e.g. for SSE/SHTTP transport).""" - headers: Dict[str, str] | None = None + headers: dict[str, str] | None = None """Headers dictionary for HTTP connections""" auth: MCPServerAuthSettings | None = None """The authentication configuration for the server.""" - roots: Optional[List[MCPRootSettings]] = None + roots: list[MCPRootSettings] | None = None """Root directories this server has access to.""" - env: Dict[str, str] | None = None + env: dict[str, str] | None = None """Environment variables to pass to the server process.""" sampling: MCPSamplingSettings | None = None @@ -250,7 +250,7 @@ def validate_transport_inference(cls, values): class MCPSettings(BaseModel): """Configuration for all MCP servers.""" - servers: Dict[str, MCPServerSettings] = {} + servers: dict[str, MCPServerSettings] = {} model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) @@ -478,7 +478,7 @@ class LoggerSettings(BaseModel): """Streaming renderer for assistant responses""" -def find_fastagent_config_files(start_path: Path) -> Tuple[Optional[Path], Optional[Path]]: +def find_fastagent_config_files(start_path: Path) -> tuple[Path | None, Path | None]: """ Find FastAgent configuration files with standardized behavior. 
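Note: Pydantic v2 resolves PEP 604 unions and built-in generics in field annotations just as it resolved the typing aliases, so the config.py models above keep their behaviour. A standalone sketch with a hypothetical model (not the repository's settings classes), assuming Python 3.10+ and pydantic v2:

    from pydantic import BaseModel

    class DemoServerSettings(BaseModel):
        url: str | None = None                 # was Optional[str]
        headers: dict[str, str] | None = None  # was Dict[str, str] | None
        args: list[str] = []                   # was List[str]; pydantic copies the mutable default per instance

    s = DemoServerSettings(url="https://example.com", headers={"Authorization": "Bearer token"})
    assert s.args == [] and s.headers == {"Authorization": "Bearer token"}
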
@@ -581,7 +581,7 @@ class Settings(BaseSettings): generic: GenericSettings | None = None """Settings for using Generic models in the fast-agent application""" - tensorzero: Optional[TensorZeroSettings] = None + tensorzero: TensorZeroSettings | None = None """Settings for using TensorZero inference gateway""" azure: AzureSettings | None = None diff --git a/src/fast_agent/context.py b/src/fast_agent/context.py index 568268e2e..2013bd26a 100644 --- a/src/fast_agent/context.py +++ b/src/fast_agent/context.py @@ -5,7 +5,7 @@ import logging import uuid from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter @@ -49,15 +49,15 @@ class Context(BaseModel): This is a global context that is shared across the application. """ - config: Optional[Settings] = None - executor: Optional[Executor] = None - human_input_handler: Optional[Any] = None - signal_notification: Optional[SignalWaitCallback] = None + config: Settings | None = None + executor: Executor | None = None + human_input_handler: Any | None = None + signal_notification: SignalWaitCallback | None = None # Registries - server_registry: Optional[ServerRegistry] = None - task_registry: Optional[ActivityRegistry] = None - skill_registry: Optional[SkillRegistry] = None + server_registry: ServerRegistry | None = None + task_registry: ActivityRegistry | None = None + skill_registry: SkillRegistry | None = None tracer: trace.Tracer | None = None _connection_manager: "MCPConnectionManager | None" = None @@ -184,7 +184,7 @@ async def configure_executor(config: "Settings"): async def initialize_context( - config: Optional[Union["Settings", str]] = None, store_globally: bool = False + config: Union["Settings", str] | None = None, store_globally: bool = False ): """ Initialize the global application context. diff --git a/src/fast_agent/core/agent_app.py b/src/fast_agent/core/agent_app.py index b141ac204..6b25a3a9a 100644 --- a/src/fast_agent/core/agent_app.py +++ b/src/fast_agent/core/agent_app.py @@ -2,7 +2,7 @@ Direct AgentApp implementation for interacting with agents without proxies. """ -from typing import Dict, List, Mapping, Optional, Union +from typing import Mapping, Union from deprecated import deprecated from mcp.types import GetPromptResult, PromptMessage @@ -27,7 +27,7 @@ class AgentApp: calls to the default agent (the first agent in the container). """ - def __init__(self, agents: Dict[str, AgentProtocol]) -> None: + def __init__(self, agents: dict[str, AgentProtocol]) -> None: """ Initialize the DirectAgentApp. 
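Note: the mixed spelling kept for forward references (for example Union["Settings", str] | None in context.py above) is deliberate. A quoted name is an ordinary string at runtime, so "Settings" | str would raise TypeError if the annotation were evaluated eagerly, whereas typing.Union wraps the string in a ForwardRef and the resulting Union object does support the | operator. A minimal sketch with a hypothetical function, assuming Python 3.10+:

    from typing import TYPE_CHECKING, Union

    if TYPE_CHECKING:
        from fast_agent.config import Settings  # imported for type checkers only

    def initialize(config: Union["Settings", str] | None = None) -> None:
        # Union["Settings", str] stores a ForwardRef and evaluates safely at import
        # time; writing "Settings" | str here would raise a TypeError instead.
        ...
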
@@ -83,7 +83,7 @@ async def __call__( async def send( self, message: Union[str, PromptMessage, PromptMessageExtended], - agent_name: Optional[str] = None, + agent_name: str | None = None, request_params: RequestParams | None = None, ) -> str: """ @@ -117,7 +117,7 @@ def _agent(self, agent_name: str | None) -> AgentProtocol: async def apply_prompt( self, prompt: Union[str, GetPromptResult], - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, agent_name: str | None = None, as_template: bool = False, ) -> str: @@ -159,7 +159,7 @@ async def list_prompts(self, namespace: str | None = None, agent_name: str | Non async def get_prompt( self, prompt_name: str, - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, server_name: str | None = None, agent_name: str | None = None, ): @@ -206,7 +206,7 @@ async def list_resources( self, server_name: str | None = None, agent_name: str | None = None, - ) -> Mapping[str, List[str]]: + ) -> Mapping[str, list[str]]: """ List available resources from one or all servers. @@ -393,7 +393,7 @@ def _show_parallel_agent_usage(self, parallel_agent) -> None: f"[dim] {prefix} {usage_data['name']}: {usage_data['display_text']}[/dim]{usage_data['cache_suffix']}" ) - def _format_agent_usage(self, agent) -> Optional[Dict]: + def _format_agent_usage(self, agent) -> dict | None: """Format usage information for a single agent.""" if not agent or not agent.usage_accumulator: return None diff --git a/src/fast_agent/core/core_app.py b/src/fast_agent/core/core_app.py index f16cab6d9..0bbae1240 100644 --- a/src/fast_agent/core/core_app.py +++ b/src/fast_agent/core/core_app.py @@ -2,7 +2,7 @@ import asyncio from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import TYPE_CHECKING, TypeVar from fast_agent.core.logging.logger import get_logger from fast_agent.event_progress import ProgressAction @@ -24,8 +24,8 @@ class Core: def __init__( self, name: str = "fast-agent", - settings: Optional[Settings] | str = None, - signal_notification: Optional[SignalWaitCallback] = None, + settings: Settings | None | str = None, + signal_notification: SignalWaitCallback | None = None, ) -> None: """ Initialize the core. @@ -43,7 +43,7 @@ def __init__( self._logger = None # Use forward reference for type to avoid runtime import - self._context: Optional["Context"] = None + self._context: "Context" | None = None self._initialized = False @property diff --git a/src/fast_agent/core/direct_decorators.py b/src/fast_agent/core/direct_decorators.py index 87fd4be17..60e412649 100644 --- a/src/fast_agent/core/direct_decorators.py +++ b/src/fast_agent/core/direct_decorators.py @@ -10,10 +10,7 @@ Any, Awaitable, Callable, - Dict, - List, Literal, - Optional, ParamSpec, Protocol, TypeVar, @@ -49,7 +46,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]: ... 
class DecoratedOrchestratorProtocol(DecoratedAgentProtocol[P, R], Protocol): """Protocol for decorated orchestrator functions with additional metadata.""" - _child_agents: List[str] + _child_agents: list[str] _plan_type: Literal["full", "iterative"] @@ -57,21 +54,21 @@ class DecoratedOrchestratorProtocol(DecoratedAgentProtocol[P, R], Protocol): class DecoratedRouterProtocol(DecoratedAgentProtocol[P, R], Protocol): """Protocol for decorated router functions with additional metadata.""" - _router_agents: List[str] + _router_agents: list[str] # Protocol for chain functions class DecoratedChainProtocol(DecoratedAgentProtocol[P, R], Protocol): """Protocol for decorated chain functions with additional metadata.""" - _chain_agents: List[str] + _chain_agents: list[str] # Protocol for parallel functions class DecoratedParallelProtocol(DecoratedAgentProtocol[P, R], Protocol): """Protocol for decorated parallel functions with additional metadata.""" - _fan_out: List[str] + _fan_out: list[str] _fan_in: str @@ -177,20 +174,20 @@ def _decorator_impl( name: str, instruction: str, *, - servers: List[str] = [], - model: Optional[str] = None, + servers: list[str] = [], + model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, + tools: dict[str, list[str]] | None = None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, skills: SkillManifest | SkillRegistry | Path | str - | List[SkillManifest | SkillRegistry | Path | str | None] + | list[SkillManifest | SkillRegistry | Path | str | None] | None = None, **extra_kwargs, ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: @@ -260,20 +257,20 @@ def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutin def agent( self, name: str = "default", - instruction_or_kwarg: Optional[str | Path | AnyUrl] = None, + instruction_or_kwarg: str | Path | AnyUrl | None = None, *, instruction: str | Path | AnyUrl = "You are a helpful agent.", - servers: List[str] = [], - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, + servers: list[str] = [], + tools: dict[str, list[str]] | None = None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, skills: SkillManifest | SkillRegistry | Path | str | None = None, - model: Optional[str] = None, + model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - elicitation_handler: Optional[ElicitationFnT] = None, + elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: """ @@ -327,20 +324,20 @@ def custom( self, cls, name: str = "default", - instruction_or_kwarg: Optional[str | Path | AnyUrl] = None, + instruction_or_kwarg: str | Path | AnyUrl | None = None, *, instruction: str | Path | AnyUrl = "You are a helpful agent.", - servers: List[str] = [], - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, + servers: list[str] = [], + tools: dict[str, list[str]] | None = 
None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, skills: SkillManifest | SkillRegistry | Path | str | None = None, - model: Optional[str] = None, + model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - elicitation_handler: Optional[ElicitationFnT] = None, + elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: """ @@ -397,9 +394,9 @@ def orchestrator( self, name: str, *, - agents: List[str], + agents: list[str], instruction: str | Path | AnyUrl = DEFAULT_INSTRUCTION_ORCHESTRATOR, - model: Optional[str] = None, + model: str | None = None, request_params: RequestParams | None = None, use_history: bool = False, human_input: bool = False, @@ -452,9 +449,9 @@ def iterative_planner( self, name: str, *, - agents: List[str], + agents: list[str], instruction: str | Path | AnyUrl = ITERATIVE_PLAN_SYSTEM_PROMPT_TEMPLATE, - model: Optional[str] = None, + model: str | None = None, request_params: RequestParams | None = None, plan_iterations: int = -1, default: bool = False, @@ -502,20 +499,20 @@ def router( self, name: str, *, - agents: List[str], - instruction: Optional[str | Path | AnyUrl] = None, - servers: List[str] = [], - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, - model: Optional[str] = None, + agents: list[str], + instruction: str | Path | AnyUrl | None = None, + servers: list[str] = [], + tools: dict[str, list[str]] | None = None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, + model: str | None = None, use_history: bool = False, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - elicitation_handler: Optional[ + elicitation_handler: ElicitationFnT - ] = None, ## exclude from docs, decide whether allowable + | None = None, ## exclude from docs, decide whether allowable api_key: str | None = None, ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: """ @@ -561,8 +558,8 @@ def chain( self, name: str, *, - sequence: List[str], - instruction: Optional[str | Path | AnyUrl] = None, + sequence: list[str], + instruction: str | Path | AnyUrl | None = None, cumulative: bool = False, default: bool = False, ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: @@ -606,9 +603,9 @@ def parallel( self, name: str, *, - fan_out: List[str], + fan_out: list[str], fan_in: str | None = None, - instruction: Optional[str | Path | AnyUrl] = None, + instruction: str | Path | AnyUrl | None = None, include_request: bool = True, default: bool = False, ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: @@ -651,7 +648,7 @@ def evaluator_optimizer( *, generator: str, evaluator: str, - instruction: Optional[str | Path | AnyUrl] = None, + instruction: str | Path | AnyUrl | None = None, min_rating: str = "GOOD", max_refinements: int = 3, default: bool = False, diff --git a/src/fast_agent/core/direct_factory.py b/src/fast_agent/core/direct_factory.py index 68df79240..bc04e0173 100644 --- a/src/fast_agent/core/direct_factory.py +++ b/src/fast_agent/core/direct_factory.py @@ -4,7 +4,7 @@ """ from functools import partial -from typing import Any, Dict, List, 
Optional, Protocol, TypeVar +from typing import Any, Protocol, TypeVar from fast_agent.agents import McpAgent from fast_agent.agents.agent_types import AgentConfig, AgentType @@ -31,8 +31,8 @@ from fast_agent.types import RequestParams # Type aliases for improved readability and IDE support -AgentDict = Dict[str, AgentProtocol] -AgentConfigDict = Dict[str, Dict[str, Any]] +AgentDict = dict[str, AgentProtocol] +AgentConfigDict = dict[str, dict[str, Any]] T = TypeVar("T") # For generic types @@ -75,18 +75,18 @@ async def __call__( app_instance: Core, agents_dict: AgentConfigDict, agent_type: AgentType, - active_agents: Optional[AgentDict] = None, - model_factory_func: Optional[ModelFactoryFunctionProtocol] = None, + active_agents: AgentDict | None = None, + model_factory_func: ModelFactoryFunctionProtocol | None = None, **kwargs: Any, ) -> AgentDict: ... def get_model_factory( context, - model: Optional[str] = None, - request_params: Optional[RequestParams] = None, - default_model: Optional[str] = None, - cli_model: Optional[str] = None, + model: str | None = None, + request_params: RequestParams | None = None, + default_model: str | None = None, + cli_model: str | None = None, ) -> LLMFactoryProtocol: """ Get model factory using specified or default model. @@ -128,7 +128,7 @@ async def create_agents_by_type( agents_dict: AgentConfigDict, agent_type: AgentType, model_factory_func: ModelFactoryFunctionProtocol, - active_agents: Optional[AgentDict] = None, + active_agents: AgentDict | None = None, **kwargs: Any, ) -> AgentDict: """ @@ -396,7 +396,7 @@ async def active_agents_in_dependency_group( app_instance: Core, agents_dict: AgentConfigDict, model_factory_func: ModelFactoryFunctionProtocol, - group: List[str], + group: list[str], active_agents: AgentDict, ): """ diff --git a/src/fast_agent/core/executor/executor.py b/src/fast_agent/core/executor/executor.py index 300f2625d..bf6842220 100644 --- a/src/fast_agent/core/executor/executor.py +++ b/src/fast_agent/core/executor/executor.py @@ -10,11 +10,9 @@ AsyncIterator, Callable, Coroutine, - Dict, - List, - Optional, Type, TypeVar, + Union, ) from pydantic import BaseModel, ConfigDict @@ -42,7 +40,7 @@ class ExecutorConfig(BaseModel): max_concurrent_activities: int | None = None # Unbounded by default timeout_seconds: timedelta | None = None # No timeout by default - retry_policy: Dict[str, Any] | None = None + retry_policy: dict[str, Any] | None = None model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) @@ -55,7 +53,7 @@ def __init__( engine: str, config: ExecutorConfig | None = None, signal_bus: SignalHandler = None, - context: Optional["Context"] = None, + context: Union["Context", None] = None, **kwargs, ) -> None: super().__init__(context=context, **kwargs) @@ -84,13 +82,13 @@ async def execute( self, *tasks: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any, - ) -> List[R | BaseException]: + ) -> list[R | BaseException]: """Execute a list of tasks and return their results""" @abstractmethod async def execute_streaming( self, - *tasks: List[Callable[..., R] | Coroutine[Any, Any, R]], + *tasks: list[Callable[..., R] | Coroutine[Any, Any, R]], **kwargs: Any, ) -> AsyncIterator[R | BaseException]: """Execute tasks and yield results as they complete""" @@ -98,13 +96,13 @@ async def execute_streaming( async def map( self, func: Callable[..., R], - inputs: List[Any], + inputs: list[Any], **kwargs: Any, - ) -> List[R | BaseException]: + ) -> list[R | BaseException]: """ Run `func(item)` for each item in `inputs` with 
concurrency limit. """ - results: List[R, BaseException] = [] + results: list[R, BaseException] = [] async def run(item): if self.config.max_concurrent_activities: @@ -233,7 +231,7 @@ async def execute( self, *tasks: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any, - ) -> List[R | BaseException]: + ) -> list[R | BaseException]: return await asyncio.gather( *(self._execute_task(task, **kwargs) for task in tasks), return_exceptions=True, @@ -241,7 +239,7 @@ async def execute( async def execute_streaming( self, - *tasks: List[Callable[..., R] | Coroutine[Any, Any, R]], + *tasks: list[Callable[..., R] | Coroutine[Any, Any, R]], **kwargs: Any, ) -> AsyncIterator[R | BaseException]: # TODO: saqadri - validate if async with self.execution_context() is needed here diff --git a/src/fast_agent/core/executor/task_registry.py b/src/fast_agent/core/executor/task_registry.py index 78b126af4..3ea76b1e6 100644 --- a/src/fast_agent/core/executor/task_registry.py +++ b/src/fast_agent/core/executor/task_registry.py @@ -4,17 +4,17 @@ The user just writes standard functions annotated with @workflow_task, but behind the scenes a workflow graph is built. """ -from typing import Any, Callable, Dict, List +from typing import Any, Callable class ActivityRegistry: """Centralized task/activity management with validation and metadata.""" def __init__(self) -> None: - self._activities: Dict[str, Callable] = {} - self._metadata: Dict[str, Dict[str, Any]] = {} + self._activities: dict[str, Callable] = {} + self._metadata: dict[str, dict[str, Any]] = {} - def register(self, name: str, func: Callable, metadata: Dict[str, Any] | None = None) -> None: + def register(self, name: str, func: Callable, metadata: dict[str, Any] | None = None) -> None: if name in self._activities: raise ValueError(f"Activity '{name}' is already registered.") self._activities[name] = func @@ -25,8 +25,8 @@ def get_activity(self, name: str) -> Callable: raise KeyError(f"Activity '{name}' not found.") return self._activities[name] - def get_metadata(self, name: str) -> Dict[str, Any]: + def get_metadata(self, name: str) -> dict[str, Any]: return self._metadata.get(name, {}) - def list_activities(self) -> List[str]: + def list_activities(self) -> list[str]: return list(self._activities.keys()) diff --git a/src/fast_agent/core/executor/workflow_signal.py b/src/fast_agent/core/executor/workflow_signal.py index f4a7e5ee5..c7fb4cf14 100644 --- a/src/fast_agent/core/executor/workflow_signal.py +++ b/src/fast_agent/core/executor/workflow_signal.py @@ -1,7 +1,7 @@ import asyncio import uuid from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generic, List, Protocol, TypeVar +from typing import Any, Callable, Generic, Protocol, TypeVar from pydantic import BaseModel, ConfigDict @@ -14,7 +14,7 @@ class Signal(BaseModel, Generic[SignalValueT]): name: str description: str | None = "Workflow Signal" payload: SignalValueT | None = None - metadata: Dict[str, Any] | None = None + metadata: dict[str, Any] | None = None workflow_id: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -71,9 +71,9 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]): def __init__(self) -> None: # Map signal_name -> list of PendingSignal objects - self._pending_signals: Dict[str, List[PendingSignal]] = {} + self._pending_signals: dict[str, list[PendingSignal]] = {} # Map signal_name -> list of (unique_name, handler) tuples - self._handlers: Dict[str, List[tuple[str, Callable]]] = {} + self._handlers: dict[str, list[tuple[str, 
Callable]]] = {} self._lock = asyncio.Lock() async def cleanup(self, signal_name: str | None = None) -> None: @@ -132,8 +132,8 @@ class ConsoleSignalHandler(SignalHandler[str]): """Simple console-based signal handling (blocks on input).""" def __init__(self) -> None: - self._pending_signals: Dict[str, List[PendingSignal]] = {} - self._handlers: Dict[str, List[Callable]] = {} + self._pending_signals: dict[str, list[PendingSignal]] = {} + self._handlers: dict[str, list[Callable]] = {} async def wait_for_signal(self, signal, timeout_seconds=None): """Block and wait for console input.""" @@ -275,7 +275,7 @@ class LocalSignalStore: def __init__(self) -> None: # For each signal_name, store a list of futures that are waiting for it - self._waiters: Dict[str, List[asyncio.Future]] = {} + self._waiters: dict[str, list[asyncio.Future]] = {} async def emit(self, signal_name: str, payload: Any) -> None: # If we have waiting futures, set their result @@ -311,7 +311,7 @@ async def __call__( signal_name: str, request_id: str | None = None, workflow_id: str | None = None, - metadata: Dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: """ Receive a notification that a workflow is pausing on a signal. diff --git a/src/fast_agent/core/fastagent.py b/src/fast_agent/core/fastagent.py index 9785a5b01..f1ec4b78c 100644 --- a/src/fast_agent/core/fastagent.py +++ b/src/fast_agent/core/fastagent.py @@ -21,10 +21,7 @@ AsyncIterator, Awaitable, Callable, - Dict, - List, Literal, - Optional, ParamSpec, TypeVar, ) @@ -135,10 +132,10 @@ def __init__( self._skills_directory_override = ( Path(skills_directory).expanduser() if skills_directory else None ) - self._default_skill_manifests: List[SkillManifest] = [] + self._default_skill_manifests: list[SkillManifest] = [] self._server_instance_factory = None self._server_instance_dispose = None - self._server_managed_instances: List[AgentInstance] = [] + self._server_managed_instances: list[AgentInstance] = [] # --- Wrap argument parsing logic --- if parse_cli_args: @@ -281,7 +278,7 @@ def __init__( raise SystemExit(1) # Dictionary to store agent configurations from decorators - self.agents: Dict[str, Dict[str, Any]] = {} + self.agents: dict[str, dict[str, Any]] = {} def _load_config(self) -> None: """Load configuration from YAML file including secrets using get_settings @@ -324,20 +321,20 @@ def context(self) -> Context: def agent( self, name: str = "default", - instruction_or_kwarg: Optional[str | Path | AnyUrl] = None, + instruction_or_kwarg: str | Path | AnyUrl | None = None, *, instruction: str | Path | AnyUrl = DEFAULT_AGENT_INSTRUCTION, - servers: List[str] = [], - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, - skills: Optional[List[SkillManifest | SkillRegistry | Path | str | None]] = None, - model: Optional[str] = None, + servers: list[str] = [], + tools: dict[str, list[str]] | None = None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, + skills: list[SkillManifest | SkillRegistry | Path | str | None] | None = None, + model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - elicitation_handler: Optional[ElicitationFnT] = None, + elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, ) -> Callable[ [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] 
@@ -347,19 +344,19 @@ def custom( self, cls, name: str = "default", - instruction_or_kwarg: Optional[str | Path | AnyUrl] = None, + instruction_or_kwarg: str | Path | AnyUrl | None = None, *, instruction: str | Path | AnyUrl = "You are a helpful agent.", - servers: List[str] = [], - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, - model: Optional[str] = None, + servers: list[str] = [], + tools: dict[str, list[str]] | None = None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, + model: str | None = None, use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - elicitation_handler: Optional[ElicitationFnT] = None, + elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... @@ -367,11 +364,11 @@ def orchestrator( self, name: str, *, - agents: List[str], + agents: list[str], instruction: str | Path | AnyUrl = "You are an expert planner. Given an objective task and a list of Agents\n(which are collections of capabilities), your job is to break down the objective\ninto a series of steps, which can be performed by these agents.\n", - model: Optional[str] = None, + model: str | None = None, request_params: RequestParams | None = None, use_history: bool = False, human_input: bool = False, @@ -385,9 +382,9 @@ def iterative_planner( self, name: str, *, - agents: List[str], + agents: list[str], instruction: str | Path | AnyUrl = "You are an expert planner. Plan iteratively.", - model: Optional[str] = None, + model: str | None = None, request_params: RequestParams | None = None, plan_iterations: int = -1, default: bool = False, @@ -398,18 +395,18 @@ def router( self, name: str, *, - agents: List[str], - instruction: Optional[str | Path | AnyUrl] = None, - servers: List[str] = [], - tools: Optional[Dict[str, List[str]]] = None, - resources: Optional[Dict[str, List[str]]] = None, - prompts: Optional[Dict[str, List[str]]] = None, - model: Optional[str] = None, + agents: list[str], + instruction: str | Path | AnyUrl | None = None, + servers: list[str] = [], + tools: dict[str, list[str]] | None = None, + resources: dict[str, list[str]] | None = None, + prompts: dict[str, list[str]] | None = None, + model: str | None = None, use_history: bool = False, request_params: RequestParams | None = None, human_input: bool = False, default: bool = False, - elicitation_handler: Optional[ElicitationFnT] = None, + elicitation_handler: ElicitationFnT | None = None, api_key: str | None = None, ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... @@ -417,8 +414,8 @@ def chain( self, name: str, *, - sequence: List[str], - instruction: Optional[str | Path | AnyUrl] = None, + sequence: list[str], + instruction: str | Path | AnyUrl | None = None, cumulative: bool = False, default: bool = False, ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... @@ -427,9 +424,9 @@ def parallel( self, name: str, *, - fan_out: List[str], + fan_out: list[str], fan_in: str | None = None, - instruction: Optional[str | Path | AnyUrl] = None, + instruction: str | Path | AnyUrl | None = None, include_request: bool = True, default: bool = False, ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... 
@@ -440,7 +437,7 @@ def evaluator_optimizer( *, generator: str, evaluator: str, - instruction: Optional[str | Path | AnyUrl] = None, + instruction: str | Path | AnyUrl | None = None, min_rating: str = "GOOD", max_refinements: int = 3, default: bool = False, @@ -477,7 +474,7 @@ async def run(self) -> AsyncIterator["AgentApp"]: Context manager for running the application. Initializes all registered agents. """ - active_agents: Dict[str, AgentProtocol] = {} + active_agents: dict[str, AgentProtocol] = {} had_error = False await self.app.initialize() @@ -503,7 +500,7 @@ async def run(self) -> AsyncIterator["AgentApp"]: self.context.skill_registry = override_registry registry = override_registry - default_skills: List[SkillManifest] = [] + default_skills: list[SkillManifest] = [] if registry: default_skills = registry.load_manifests() @@ -733,7 +730,7 @@ async def dispose_agent_instance(instance: AgentInstance) -> None: if hasattr(self.args, "prompt_file") and self.args.prompt_file: agent_name = self.args.agent - prompt: List[PromptMessageExtended] = load_prompt( + prompt: list[PromptMessageExtended] = load_prompt( Path(self.args.prompt_file) ) if agent_name not in active_agents: @@ -845,7 +842,7 @@ def _apply_instruction_context( if request_params is not None: request_params.systemPrompt = resolved - def _apply_skills_to_agent_configs(self, default_skills: List[SkillManifest]) -> None: + def _apply_skills_to_agent_configs(self, default_skills: list[SkillManifest]) -> None: self._default_skill_manifests = list(default_skills) for agent_data in self.agents.values(): @@ -867,13 +864,13 @@ def _resolve_skills( | SkillRegistry | Path | str - | List[SkillManifest | SkillRegistry | Path | str | None] + | list[SkillManifest | SkillRegistry | Path | str | None] | None, - ) -> List[SkillManifest]: + ) -> list[SkillManifest]: if entry is None: return [] if isinstance(entry, list): - manifests: List[SkillManifest] = [] + manifests: list[SkillManifest] = [] for item in entry: manifests.extend(self._resolve_skills(item)) return manifests @@ -900,15 +897,15 @@ def _resolve_skills( return [] @staticmethod - def _deduplicate_skills(manifests: List[SkillManifest]) -> List[SkillManifest]: - unique: Dict[str, SkillManifest] = {} + def _deduplicate_skills(manifests: list[SkillManifest]) -> list[SkillManifest]: + unique: dict[str, SkillManifest] = {} for manifest in manifests: key = manifest.name.lower() if key not in unique: unique[key] = manifest return list(unique.values()) - def _handle_error(self, e: Exception, error_type: Optional[str] = None) -> None: + def _handle_error(self, e: Exception, error_type: str | None = None) -> None: """ Handle errors with consistent formatting and messaging. 
@@ -975,9 +972,9 @@ async def start_server( transport: str = "http", host: str = "0.0.0.0", port: int = 8000, - server_name: Optional[str] = None, - server_description: Optional[str] = None, - tool_description: Optional[str] = None, + server_name: str | None = None, + server_description: str | None = None, + tool_description: str | None = None, instance_scope: str = "shared", ) -> None: """ @@ -1038,9 +1035,9 @@ async def run_with_mcp_server( transport: str = "sse", host: str = "0.0.0.0", port: int = 8000, - server_name: Optional[str] = None, - server_description: Optional[str] = None, - tool_description: Optional[str] = None, + server_name: str | None = None, + server_description: str | None = None, + tool_description: str | None = None, instance_scope: str = "shared", ) -> None: """ @@ -1100,7 +1097,7 @@ async def app_main(): @dataclass class AgentInstance: app: AgentApp - agents: Dict[str, "AgentProtocol"] + agents: dict[str, "AgentProtocol"] async def shutdown(self) -> None: for agent in self.agents.values(): diff --git a/src/fast_agent/core/logging/events.py b/src/fast_agent/core/logging/events.py index faa84e0e8..ca173c632 100644 --- a/src/fast_agent/core/logging/events.py +++ b/src/fast_agent/core/logging/events.py @@ -5,7 +5,7 @@ import logging import random from datetime import datetime -from typing import Any, Dict, Literal +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field @@ -21,10 +21,10 @@ class EventContext(BaseModel): session_id: str | None = None workflow_id: str | None = None - # request_id: Optional[str] = None - # parent_event_id: Optional[str] = None - # correlation_id: Optional[str] = None - # user_id: Optional[str] = None + # request_id: str | None = None + # parent_event_id: str | None = None + # correlation_id: str | None = None + # user_id: str | None = None model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) @@ -40,7 +40,7 @@ class Event(BaseModel): namespace: str message: str timestamp: datetime = Field(default_factory=datetime.now) - data: Dict[str, Any] = Field(default_factory=dict) + data: dict[str, Any] = Field(default_factory=dict) context: EventContext | None = None # For distributed tracing @@ -84,7 +84,7 @@ def matches(self, event: Event) -> bool: # 4) Minimum severity if self.min_level: - level_map: Dict[EventType, int] = { + level_map: dict[EventType, int] = { "debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING, diff --git a/src/fast_agent/core/logging/json_serializer.py b/src/fast_agent/core/logging/json_serializer.py index 729f8af3c..4bf0bf37f 100644 --- a/src/fast_agent/core/logging/json_serializer.py +++ b/src/fast_agent/core/logging/json_serializer.py @@ -6,7 +6,7 @@ from decimal import Decimal from enum import Enum from pathlib import Path -from typing import Any, Dict, Iterable, Set +from typing import Any, Iterable from uuid import UUID import httpx @@ -37,7 +37,7 @@ class JSONSerializer: def __init__(self) -> None: # Set of already processed objects to prevent infinite recursion - self._processed_objects: Set[int] = set() + self._processed_objects: set[int] = set() # Check if secrets should be logged in full self._log_secrets = os.getenv("LOG_SECRETS", "").upper() == "TRUE" @@ -126,7 +126,7 @@ def _serialize_object(self, obj: Any, depth: int = 0) -> Any: return self._serialize_object(obj.to_dict()) # Handle dictionaries with sensitive data redaction - if isinstance(obj, Dict): + if isinstance(obj, dict): return { str(key): self._redact_sensitive_value(value) if 
self._is_sensitive_key(key) diff --git a/src/fast_agent/core/logging/listeners.py b/src/fast_agent/core/logging/listeners.py index b486abc01..5154aefe0 100644 --- a/src/fast_agent/core/logging/listeners.py +++ b/src/fast_agent/core/logging/listeners.py @@ -6,7 +6,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING if TYPE_CHECKING: from fast_agent.event_progress import ProgressEvent @@ -164,7 +164,7 @@ def __init__( self.logger = logger or logging.getLogger("fast_agent") async def handle_matched_event(self, event) -> None: - level_map: Dict[EventType, int] = { + level_map: dict[EventType, int] = { "debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING, @@ -246,7 +246,7 @@ def __init__( super().__init__(event_filter=event_filter) self.batch_size = batch_size self.flush_interval = flush_interval - self.batch: List[Event] = [] + self.batch: list[Event] = [] self.last_flush: float = time.time() # Time of last flush self._flush_task: asyncio.Task | None = None # Task for periodic flush loop self._stop_event = None # Event to signal flush task to stop @@ -293,5 +293,5 @@ async def flush(self) -> None: self.last_flush = time.time() await self._process_batch(to_process) - async def _process_batch(self, events: List[Event]) -> None: + async def _process_batch(self, events: list[Event]) -> None: pass diff --git a/src/fast_agent/core/logging/logger.py b/src/fast_agent/core/logging/logger.py index 954c9a0e1..a19f182e4 100644 --- a/src/fast_agent/core/logging/logger.py +++ b/src/fast_agent/core/logging/logger.py @@ -12,7 +12,7 @@ import threading import time from contextlib import asynccontextmanager, contextmanager -from typing import Any, Dict +from typing import Any from fast_agent.core.logging.events import Event, EventContext, EventFilter, EventType from fast_agent.core.logging.listeners import ( @@ -257,7 +257,7 @@ async def managed(cls, **config_kwargs): _logger_lock = threading.Lock() -_loggers: Dict[str, Logger] = {} +_loggers: dict[str, Logger] = {} def get_logger(namespace: str) -> Logger: @@ -304,7 +304,7 @@ def get_logger(namespace: str) -> Logger: # class Workflow: # """Example workflow that logs multiple steps, also with optional tracing.""" -# def __init__(self, name: str, steps: List[str]): +# def __init__(self, name: str, steps: list[str]): # self.logger = Logger(f"workflow.{name}") # self.steps = steps diff --git a/src/fast_agent/core/logging/transport.py b/src/fast_agent/core/logging/transport.py index d6604c13f..71f210855 100644 --- a/src/fast_agent/core/logging/transport.py +++ b/src/fast_agent/core/logging/transport.py @@ -9,7 +9,7 @@ import traceback from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Protocol +from typing import Protocol import aiohttp from opentelemetry import trace @@ -71,7 +71,7 @@ def __init__(self, event_filter: EventFilter | None = None) -> None: super().__init__(event_filter=event_filter) # Use shared console instances self._serializer = JSONSerializer() - self.log_level_styles: Dict[str, str] = { + self.log_level_styles: dict[str, str] = { "info": "bold green", "debug": "dim white", "warning": "bold yellow", @@ -182,7 +182,7 @@ class HTTPTransport(FilteredEventTransport): def __init__( self, endpoint: str, - headers: Dict[str, str] = None, + headers: dict[str, str] = None, batch_size: int = 100, timeout: float = 5.0, event_filter: EventFilter | None = None, @@ -193,7 +193,7 @@ def __init__( 
self.batch_size = batch_size self.timeout = timeout - self.batch: List[Event] = [] + self.batch: list[Event] = [] self.lock = asyncio.Lock() self._session: aiohttp.ClientSession | None = None self._serializer = JSONSerializer() @@ -268,7 +268,7 @@ class AsyncEventBus: def __init__(self, transport: EventTransport | None = None) -> None: self.transport: EventTransport = transport or NoOpTransport() - self.listeners: Dict[str, EventListener] = {} + self.listeners: dict[str, EventListener] = {} self._queue: asyncio.Queue | None = None self._task: asyncio.Task | None = None self._running = False diff --git a/src/fast_agent/core/validation.py b/src/fast_agent/core/validation.py index 66cb5c25f..5479d78f4 100644 --- a/src/fast_agent/core/validation.py +++ b/src/fast_agent/core/validation.py @@ -2,7 +2,7 @@ Validation utilities for FastAgent configuration and dependencies. """ -from typing import Any, Dict, List +from typing import Any from fast_agent.agents.agent_types import AgentType from fast_agent.core.exceptions import ( @@ -13,7 +13,7 @@ from fast_agent.llm.fastagent_llm import FastAgentLLM -def validate_server_references(context, agents: Dict[str, Dict[str, Any]]) -> None: +def validate_server_references(context, agents: dict[str, dict[str, Any]]) -> None: """ Validate that all server references in agent configurations exist in config. Raises ServerConfigError if any referenced servers are not defined. @@ -39,7 +39,7 @@ def validate_server_references(context, agents: Dict[str, Dict[str, Any]]) -> No ) -def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None: +def validate_workflow_references(agents: dict[str, dict[str, Any]]) -> None: """ Validate that all workflow references point to valid agents/workflows. Also validates that referenced agents have required configuration. @@ -140,11 +140,11 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None: def get_dependencies( name: str, - agents: Dict[str, Dict[str, Any]], + agents: dict[str, dict[str, Any]], visited: set, path: set, agent_type: AgentType = None, -) -> List[str]: +) -> list[str]: """ Get dependencies for an agent in topological order. Works for both Parallel and Chain workflows. @@ -229,8 +229,8 @@ def get_agent_dependencies(agent_data: dict[str, Any]) -> set[str]: def get_dependencies_groups( - agents_dict: Dict[str, Dict[str, Any]], allow_cycles: bool = False -) -> List[List[str]]: + agents_dict: dict[str, dict[str, Any]], allow_cycles: bool = False +) -> list[list[str]]: """ Get dependencies between agents and group them into dependency layers. Each layer can be initialized in parallel. @@ -305,7 +305,7 @@ def visit(node) -> None: return result -def validate_provider_keys_post_creation(active_agents: Dict[str, Any]) -> None: +def validate_provider_keys_post_creation(active_agents: dict[str, Any]) -> None: """ Validate that API keys are available for all created agents with LLMs. 
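Note: since Python 3.9 the built-in containers are subscriptable directly, so annotations such as dict[str, dict[str, Any]] and list[list[str]] in validation.py above need no typing imports beyond Any itself. A small self-contained illustration with a hypothetical helper:

    from typing import Any

    def group_agents_by_type(agents: dict[str, dict[str, Any]]) -> dict[str, list[str]]:
        # Group agent names by their declared "type" field.
        grouped: dict[str, list[str]] = {}
        for name, data in agents.items():
            grouped.setdefault(str(data.get("type", "unknown")), []).append(name)
        return grouped

    assert group_agents_by_type(
        {"writer": {"type": "agent"}, "router": {"type": "router"}, "editor": {"type": "agent"}}
    ) == {"agent": ["writer", "editor"], "router": ["router"]}
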
diff --git a/src/fast_agent/event_progress.py b/src/fast_agent/event_progress.py index 0cd4c7d34..bc29dfcf9 100644 --- a/src/fast_agent/event_progress.py +++ b/src/fast_agent/event_progress.py @@ -1,7 +1,6 @@ """Module for converting log events to progress events.""" from enum import Enum -from typing import Optional from pydantic import BaseModel @@ -35,11 +34,11 @@ class ProgressEvent(BaseModel): action: ProgressAction target: str - details: Optional[str] = None - agent_name: Optional[str] = None - streaming_tokens: Optional[str] = None # Special field for streaming token count - progress: Optional[float] = None # Current progress value - total: Optional[float] = None # Total value for progress calculation + details: str | None = None + agent_name: str | None = None + streaming_tokens: str | None = None # Special field for streaming token count + progress: float | None = None # Current progress value + total: float | None = None # Total value for progress calculation def __str__(self) -> str: """Format the progress event for display.""" diff --git a/src/fast_agent/history/history_exporter.py b/src/fast_agent/history/history_exporter.py index 40a7057d5..edca5cb6a 100644 --- a/src/fast_agent/history/history_exporter.py +++ b/src/fast_agent/history/history_exporter.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from fast_agent.mcp.prompt_serialization import save_messages @@ -20,7 +20,7 @@ class HistoryExporter: """Utility for exporting agent history to a file.""" @staticmethod - async def save(agent: AgentProtocol, filename: Optional[str] = None) -> str: + async def save(agent: AgentProtocol, filename: str | None = None) -> str: """ Save the given agent's message history to a file. 
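Note: history_exporter.py above also uses from __future__ import annotations, which stores every annotation as a string instead of evaluating it at import time; the new syntax is then purely textual until something like typing.get_type_hints() resolves it (which still requires Python 3.10+ for str | None). A quick sketch of that behaviour using a simplified stand-in for the exporter's save():

    from __future__ import annotations

    import typing

    def save(agent: object, filename: str | None = None) -> str:
        return filename or "history.json"

    # Deferred evaluation keeps the raw annotation text...
    assert save.__annotations__["filename"] == "str | None"
    # ...and resolving it yields a real union type on Python 3.10+.
    assert typing.get_type_hints(save)["filename"] == (str | type(None))
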
diff --git a/src/fast_agent/human_input/elicitation_handler.py b/src/fast_agent/human_input/elicitation_handler.py index b63eebf60..452b9568c 100644 --- a/src/fast_agent/human_input/elicitation_handler.py +++ b/src/fast_agent/human_input/elicitation_handler.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from fast_agent.human_input.elicitation_state import elicitation_state from fast_agent.human_input.types import ( @@ -47,7 +47,7 @@ async def elicitation_input_callback( ) # Get the elicitation schema from metadata - schema: Optional[Dict[str, Any]] = None + schema: dict[str, Any] | None = None if request.metadata and "requested_schema" in request.metadata: schema = request.metadata["requested_schema"] diff --git a/src/fast_agent/human_input/elicitation_state.py b/src/fast_agent/human_input/elicitation_state.py index 6dc91c365..c11a7d54f 100644 --- a/src/fast_agent/human_input/elicitation_state.py +++ b/src/fast_agent/human_input/elicitation_state.py @@ -1,13 +1,12 @@ """Simple state management for elicitation Cancel All functionality.""" -from typing import Set class ElicitationState: """Manages global state for elicitation requests, including disabled servers.""" def __init__(self): - self.disabled_servers: Set[str] = set() + self.disabled_servers: set[str] = set() def disable_server(self, server_name: str) -> None: """Disable elicitation requests for a specific server.""" @@ -25,7 +24,7 @@ def clear_all(self) -> None: """Clear all disabled servers.""" self.disabled_servers.clear() - def get_disabled_servers(self) -> Set[str]: + def get_disabled_servers(self) -> set[str]: """Get a copy of all disabled servers.""" return self.disabled_servers.copy() diff --git a/src/fast_agent/human_input/form_fields.py b/src/fast_agent/human_input/form_fields.py index 1579e3882..c633c961a 100644 --- a/src/fast_agent/human_input/form_fields.py +++ b/src/fast_agent/human_input/form_fields.py @@ -1,7 +1,7 @@ """High-level field types for elicitation forms with default support.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Union @dataclass @@ -16,9 +16,9 @@ class StringField: pattern: str | None = None format: str | None = None # email, uri, date, date-time - def to_schema(self) -> Dict[str, Any]: + def to_schema(self) -> dict[str, Any]: """Convert to MCP elicitation schema format.""" - schema: Dict[str, Any] = {"type": "string"} + schema: dict[str, Any] = {"type": "string"} if self.title: schema["title"] = self.title @@ -42,15 +42,15 @@ def to_schema(self) -> Dict[str, Any]: class IntegerField: """Integer field with validation and default support.""" - title: Optional[str] = None - description: Optional[str] = None - default: Optional[int] = None - minimum: Optional[int] = None - maximum: Optional[int] = None + title: str | None = None + description: str | None = None + default: int | None = None + minimum: int | None = None + maximum: int | None = None - def to_schema(self) -> Dict[str, Any]: + def to_schema(self) -> dict[str, Any]: """Convert to MCP elicitation schema format.""" - schema: Dict[str, Any] = {"type": "integer"} + schema: dict[str, Any] = {"type": "integer"} if self.title: schema["title"] = self.title @@ -70,15 +70,15 @@ def to_schema(self) -> Dict[str, Any]: class NumberField: """Number (float) field with validation and default support.""" - title: Optional[str] = None - description: Optional[str] = None - default: Optional[float] = None - minimum: Optional[float] = None - maximum: 
Optional[float] = None + title: str | None = None + description: str | None = None + default: float | None = None + minimum: float | None = None + maximum: float | None = None - def to_schema(self) -> Dict[str, Any]: + def to_schema(self) -> dict[str, Any]: """Convert to MCP elicitation schema format.""" - schema: Dict[str, Any] = {"type": "number"} + schema: dict[str, Any] = {"type": "number"} if self.title: schema["title"] = self.title @@ -98,13 +98,13 @@ def to_schema(self) -> Dict[str, Any]: class BooleanField: """Boolean field with default support.""" - title: Optional[str] = None - description: Optional[str] = None - default: Optional[bool] = None + title: str | None = None + description: str | None = None + default: bool | None = None - def to_schema(self) -> Dict[str, Any]: + def to_schema(self) -> dict[str, Any]: """Convert to MCP elicitation schema format.""" - schema: Dict[str, Any] = {"type": "boolean"} + schema: dict[str, Any] = {"type": "boolean"} if self.title: schema["title"] = self.title @@ -120,15 +120,15 @@ def to_schema(self) -> Dict[str, Any]: class EnumField: """Enum/choice field with default support.""" - choices: List[str] - choice_names: Optional[List[str]] = None # Human-readable names - title: Optional[str] = None - description: Optional[str] = None - default: Optional[str] = None + choices: list[str] + choice_names: list[str] | None = None # Human-readable names + title: str | None = None + description: str | None = None + default: str | None = None - def to_schema(self) -> Dict[str, Any]: + def to_schema(self) -> dict[str, Any]: """Convert to MCP elicitation schema format.""" - schema: Dict[str, Any] = {"type": "string", "enum": self.choices} + schema: dict[str, Any] = {"type": "string", "enum": self.choices} if self.title: schema["title"] = self.title @@ -152,21 +152,21 @@ class FormSchema: def __init__(self, **fields: FieldType): """Create a form schema with named fields.""" self.fields = fields - self._required_fields: List[str] = [] + self._required_fields: list[str] = [] def required(self, *field_names: str) -> "FormSchema": """Mark fields as required.""" self._required_fields.extend(field_names) return self - def to_schema(self) -> Dict[str, Any]: + def to_schema(self) -> dict[str, Any]: """Convert to MCP ElicitRequestedSchema format.""" properties = {} for field_name, field in self.fields.items(): properties[field_name] = field.to_schema() - schema: Dict[str, Any] = {"type": "object", "properties": properties} + schema: dict[str, Any] = {"type": "object", "properties": properties} if self._required_fields: schema["required"] = self._required_fields @@ -176,81 +176,81 @@ def to_schema(self) -> Dict[str, Any]: # Convenience functions for creating fields def string( - title: Optional[str] = None, - description: Optional[str] = None, - default: Optional[str] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - pattern: Optional[str] = None, - format: Optional[str] = None, + title: str | None = None, + description: str | None = None, + default: str | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + format: str | None = None, ) -> StringField: """Create a string field.""" return StringField(title, description, default, min_length, max_length, pattern, format) def email( - title: Optional[str] = None, description: Optional[str] = None, default: Optional[str] = None + title: str | None = None, description: str | None = None, default: str | None = None ) -> StringField: 
"""Create an email field.""" return StringField(title, description, default, format="email") def url( - title: Optional[str] = None, description: Optional[str] = None, default: Optional[str] = None + title: str | None = None, description: str | None = None, default: str | None = None ) -> StringField: """Create a URL field.""" return StringField(title, description, default, format="uri") def date( - title: Optional[str] = None, description: Optional[str] = None, default: Optional[str] = None + title: str | None = None, description: str | None = None, default: str | None = None ) -> StringField: """Create a date field.""" return StringField(title, description, default, format="date") def datetime( - title: Optional[str] = None, description: Optional[str] = None, default: Optional[str] = None + title: str | None = None, description: str | None = None, default: str | None = None ) -> StringField: """Create a datetime field.""" return StringField(title, description, default, format="date-time") def integer( - title: Optional[str] = None, - description: Optional[str] = None, - default: Optional[int] = None, - minimum: Optional[int] = None, - maximum: Optional[int] = None, + title: str | None = None, + description: str | None = None, + default: int | None = None, + minimum: int | None = None, + maximum: int | None = None, ) -> IntegerField: """Create an integer field.""" return IntegerField(title, description, default, minimum, maximum) def number( - title: Optional[str] = None, - description: Optional[str] = None, - default: Optional[float] = None, - minimum: Optional[float] = None, - maximum: Optional[float] = None, + title: str | None = None, + description: str | None = None, + default: float | None = None, + minimum: float | None = None, + maximum: float | None = None, ) -> NumberField: """Create a number field.""" return NumberField(title, description, default, minimum, maximum) def boolean( - title: Optional[str] = None, description: Optional[str] = None, default: Optional[bool] = None + title: str | None = None, description: str | None = None, default: bool | None = None ) -> BooleanField: """Create a boolean field.""" return BooleanField(title, description, default) def choice( - choices: List[str], - choice_names: Optional[List[str]] = None, - title: Optional[str] = None, - description: Optional[str] = None, - default: Optional[str] = None, + choices: list[str], + choice_names: list[str] | None = None, + title: str | None = None, + description: str | None = None, + default: str | None = None, ) -> EnumField: """Create a choice/enum field.""" return EnumField(choices, choice_names, title, description, default) diff --git a/src/fast_agent/human_input/simple_form.py b/src/fast_agent/human_input/simple_form.py index b779a8cb6..d0612a3f6 100644 --- a/src/fast_agent/human_input/simple_form.py +++ b/src/fast_agent/human_input/simple_form.py @@ -1,7 +1,7 @@ """Simple form API for elicitation schemas without MCP wrappers.""" import asyncio -from typing import Any, Dict, Optional, Union +from typing import Any, Union from mcp.types import ElicitRequestedSchema @@ -9,10 +9,10 @@ async def form( - schema: Union[FormSchema, ElicitRequestedSchema, Dict[str, Any]], + schema: Union[FormSchema, ElicitRequestedSchema, dict[str, Any]], message: str = "Please fill out the form", title: str = "Form Input", -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """ Simple form API that presents an elicitation form and returns results. 
@@ -61,10 +61,10 @@ async def form( def form_sync( - schema: Union[FormSchema, ElicitRequestedSchema, Dict[str, Any]], + schema: Union[FormSchema, ElicitRequestedSchema, dict[str, Any]], message: str = "Please fill out the form", title: str = "Form Input", -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """ Synchronous wrapper for the form function. @@ -81,9 +81,9 @@ def form_sync( # Convenience function with a shorter name async def ask( - schema: Union[FormSchema, ElicitRequestedSchema, Dict[str, Any]], + schema: Union[FormSchema, ElicitRequestedSchema, dict[str, Any]], message: str = "Please provide the requested information", -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """ Short alias for form() function. @@ -101,9 +101,9 @@ async def ask( def ask_sync( - schema: Union[FormSchema, ElicitRequestedSchema, Dict[str, Any]], + schema: Union[FormSchema, ElicitRequestedSchema, dict[str, Any]], message: str = "Please provide the requested information", -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """ Synchronous version of ask(). diff --git a/src/fast_agent/interfaces.py b/src/fast_agent/interfaces.py index 55038fc62..7b960e718 100644 --- a/src/fast_agent/interfaces.py +++ b/src/fast_agent/interfaces.py @@ -9,12 +9,9 @@ TYPE_CHECKING, Any, Callable, - Dict, - List, Mapping, Protocol, Sequence, - Tuple, Type, TypeVar, Union, @@ -67,16 +64,16 @@ class FastAgentLLMProtocol(Protocol): async def structured( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: ... + ) -> tuple[ModelT | None, PromptMessageExtended]: ... async def generate( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: ... async def apply_prompt_template( @@ -91,11 +88,11 @@ def get_request_params( def add_stream_listener(self, listener: Callable[[str], None]) -> Callable[[], None]: ... def add_tool_stream_listener( - self, listener: Callable[[str, Dict[str, Any] | None], None] + self, listener: Callable[[str, dict[str, Any] | None], None] ) -> Callable[[], None]: ... @property - def message_history(self) -> List[PromptMessageExtended]: ... + def message_history(self) -> list[PromptMessageExtended]: ... def pop_last_message(self) -> PromptMessageExtended | None: ... @@ -182,10 +179,10 @@ async def structured( ], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: ... + ) -> tuple[ModelT | None, PromptMessageExtended]: ... @property - def message_history(self) -> List[PromptMessageExtended]: ... + def message_history(self) -> list[PromptMessageExtended]: ... @property def usage_accumulator(self) -> UsageAccumulator | None: ... @@ -193,7 +190,7 @@ def usage_accumulator(self) -> UsageAccumulator | None: ... async def apply_prompt( self, prompt: Union[str, "GetPromptResult"], - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, as_template: bool = False, namespace: str | None = None, ) -> str: ... @@ -201,15 +198,15 @@ async def apply_prompt( async def get_prompt( self, prompt_name: str, - arguments: Dict[str, str] | None = None, + arguments: dict[str, str] | None = None, namespace: str | None = None, ) -> GetPromptResult: ... 
- async def list_prompts(self, namespace: str | None = None) -> Mapping[str, List[Prompt]]: ... + async def list_prompts(self, namespace: str | None = None) -> Mapping[str, list[Prompt]]: ... - async def list_resources(self, namespace: str | None = None) -> Mapping[str, List[str]]: ... + async def list_resources(self, namespace: str | None = None) -> Mapping[str, list[str]]: ... - async def list_mcp_tools(self, namespace: str | None = None) -> Mapping[str, List[Tool]]: ... + async def list_mcp_tools(self, namespace: str | None = None) -> Mapping[str, list[Tool]]: ... async def list_tools(self) -> ListToolsResult: ... @@ -235,8 +232,8 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend async def show_assistant_message( self, message: PromptMessageExtended, - bottom_items: List[str] | None = None, - highlight_items: str | List[str] | None = None, + bottom_items: list[str] | None = None, + highlight_items: str | list[str] | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, @@ -262,5 +259,5 @@ class StreamingAgentProtocol(AgentProtocol, Protocol): def add_stream_listener(self, listener: Callable[[str], None]) -> Callable[[], None]: ... def add_tool_stream_listener( - self, listener: Callable[[str, Dict[str, Any] | None], None] + self, listener: Callable[[str, dict[str, Any] | None], None] ) -> Callable[[], None]: ... diff --git a/src/fast_agent/llm/fastagent_llm.py b/src/fast_agent/llm/fastagent_llm.py index 74599eb53..925614231 100644 --- a/src/fast_agent/llm/fastagent_llm.py +++ b/src/fast_agent/llm/fastagent_llm.py @@ -6,13 +6,10 @@ TYPE_CHECKING, Any, Callable, - Dict, Generic, - List, - Optional, - Tuple, Type, TypeVar, + Union, cast, ) @@ -56,10 +53,10 @@ # Context variable for storing MCP metadata -_mcp_metadata_var: ContextVar[Dict[str, Any] | None] = ContextVar("mcp_metadata", default=None) +_mcp_metadata_var: ContextVar[dict[str, Any] | None] = ContextVar("mcp_metadata", default=None) -def deep_merge(dict1: Dict[Any, Any], dict2: Dict[Any, Any]) -> Dict[Any, Any]: +def deep_merge(dict1: dict[Any, Any], dict2: dict[Any, Any]) -> dict[Any, Any]: """ Recursively merges `dict2` into `dict1` in place. @@ -110,9 +107,9 @@ def __init__( instruction: str | None = None, name: str | None = None, request_params: RequestParams | None = None, - context: Optional["Context"] = None, - model: Optional[str] = None, - api_key: Optional[str] = None, + context: Union["Context", None] = None, + model: str | None = None, + api_key: str | None = None, **kwargs: dict[str, Any], ) -> None: """ @@ -155,7 +152,7 @@ def __init__( ) # Cache effective model name for type-safe access - self._model_name: Optional[str] = getattr(self.default_request_params, "model", None) + self._model_name: str | None = getattr(self.default_request_params, "model", None) self.verb = kwargs.get("verb") @@ -164,7 +161,7 @@ def __init__( # Initialize usage tracking self._usage_accumulator = UsageAccumulator() self._stream_listeners: set[Callable[[str], None]] = set() - self._tool_stream_listeners: set[Callable[[str, Dict[str, Any] | None], None]] = set() + self._tool_stream_listeners: set[Callable[[str, dict[str, Any] | None], None]] = set() def _initialize_default_params(self, kwargs: dict) -> RequestParams: """Initialize default parameters for the LLM. 
@@ -184,15 +181,15 @@ def _initialize_default_params(self, kwargs: dict) -> RequestParams: async def generate( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Generate a completion using normalized message lists. This is the primary LLM interface that works directly with - List[PromptMessageExtended] for efficient internal usage. + list[PromptMessageExtended] for efficient internal usage. Args: messages: List of PromptMessageExtended objects @@ -251,9 +248,9 @@ async def generate( @abstractmethod async def _apply_prompt_provider_specific( self, - multipart_messages: List["PromptMessageExtended"], + multipart_messages: list["PromptMessageExtended"], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: """ @@ -272,15 +269,15 @@ async def _apply_prompt_provider_specific( async def structured( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """ Generate a structured response using normalized message lists. This is the primary LLM interface for structured output that works directly with - List[PromptMessageExtended] for efficient internal usage. + list[PromptMessageExtended] for efficient internal usage. Args: messages: List of PromptMessageExtended objects @@ -363,10 +360,10 @@ def model_to_schema_str( async def _apply_prompt_provider_specific_structured( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """Base class attempts to parse JSON - subclasses can use provider specific functionality""" request_params = self.get_request_params(request_params) @@ -383,7 +380,7 @@ async def _apply_prompt_provider_specific_structured( def _structured_from_multipart( self, message: PromptMessageExtended, model: Type[ModelT] - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """Parse the content of a PromptMessage and return the structured model and message itself""" try: text = get_text(message.content[-1]) or "" @@ -400,11 +397,11 @@ def _prepare_structured_text(self, text: str) -> str: """Hook for subclasses to adjust structured output text before parsing.""" return text - def record_templates(self, templates: List[PromptMessageExtended]) -> None: + def record_templates(self, templates: list[PromptMessageExtended]) -> None: """Hook for providers that need template visibility (e.g., caching).""" return - def _precall(self, multipart_messages: List[PromptMessageExtended]) -> None: + def _precall(self, multipart_messages: list[PromptMessageExtended]) -> None: """Pre-call hook to modify the message before sending it to the provider.""" # No-op placeholder; history is managed by the agent @@ -503,7 +500,7 @@ def _finalize_turn_usage(self, turn_usage: "TurnUsage") -> None: self._usage_accumulator.add_turn(turn_usage) def _log_chat_progress( - self, chat_turn: Optional[int] = None, model: Optional[str] = None + self, chat_turn: 
int | None = None, model: str | None = None ) -> None: """Log a chat progress event""" # Determine action type based on verb @@ -582,7 +579,7 @@ def _notify_stream_listeners(self, chunk: str) -> None: self.logger.exception("Stream listener raised an exception") def add_tool_stream_listener( - self, listener: Callable[[str, Dict[str, Any] | None], None] + self, listener: Callable[[str, dict[str, Any] | None], None] ) -> Callable[[], None]: """Register a callback invoked with tool streaming events. @@ -601,7 +598,7 @@ def remove() -> None: return remove def _notify_tool_stream_listeners( - self, event_type: str, payload: Dict[str, Any] | None = None + self, event_type: str, payload: dict[str, Any] | None = None ) -> None: """Notify listeners about tool streaming lifecycle events.""" @@ -612,7 +609,7 @@ def _notify_tool_stream_listeners( except Exception: self.logger.exception("Tool stream listener raised an exception") - def _log_chat_finished(self, model: Optional[str] = None) -> None: + def _log_chat_finished(self, model: str | None = None) -> None: """Log a chat finished event""" data = { "progress_action": ProgressAction.READY, @@ -621,7 +618,7 @@ def _log_chat_finished(self, model: Optional[str] = None) -> None: } self.logger.debug("Chat finished", data=data) - def _convert_prompt_messages(self, prompt_messages: List[PromptMessage]) -> List[MessageParamT]: + def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> list[MessageParamT]: """ Convert prompt messages to this LLM's specific message format. To be implemented by concrete LLM classes. @@ -629,8 +626,8 @@ def _convert_prompt_messages(self, prompt_messages: List[PromptMessage]) -> List raise NotImplementedError("Must be implemented by subclass") def _convert_to_provider_format( - self, messages: List[PromptMessageExtended] - ) -> List[MessageParamT]: + self, messages: list[PromptMessageExtended] + ) -> list[MessageParamT]: """ Convert provided messages to provider-specific format. Called fresh on EVERY API call - no caching. @@ -645,8 +642,8 @@ def _convert_to_provider_format( @abstractmethod def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[MessageParamT]: + self, messages: list[PromptMessageExtended] + ) -> list[MessageParamT]: """ Convert PromptMessageExtended list to provider-specific format. Must be implemented by each provider. @@ -662,9 +659,9 @@ def _convert_extended_messages_to_provider( async def show_prompt_loaded( self, prompt_name: str, - description: Optional[str] = None, + description: str | None = None, message_count: int = 0, - arguments: Optional[dict[str, str]] = None, + arguments: dict[str, str] | None = None, ) -> None: """ Display information about a loaded prompt template. @@ -721,7 +718,7 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_nam ) return result.first_text() - async def _save_history(self, filename: str, messages: List[PromptMessageExtended]) -> None: + async def _save_history(self, filename: str, messages: list[PromptMessageExtended]) -> None: """ Save the Message History to a file in a format determined by the file extension. @@ -741,7 +738,7 @@ async def _save_history(self, filename: str, messages: List[PromptMessageExtende save_messages(filtered, filename) @property - def message_history(self) -> List[PromptMessageExtended]: + def message_history(self) -> list[PromptMessageExtended]: """ Return the agent's message history as PromptMessageExtended objects. 
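A brief note on the two union spellings used in the hunk above: a runtime-evaluated annotation cannot apply | to a string forward reference, which is why the quoted Context parameter keeps Union[...] while concrete types move to X | None. A minimal illustration, where Context is only a stand-in class:

    from typing import Union

    class Context: ...

    # Concrete types: PEP 604 unions evaluate cleanly at definition time.
    def a(ctx: Context | None = None) -> None: ...

    # String forward references: "Context" | None would be evaluated eagerly and
    # raise TypeError (str does not support |), so the quoted form stays Union[...].
    def b(ctx: Union["Context", None] = None) -> None: ...

    # Under `from __future__ import annotations` nothing is evaluated at def time,
    # so either spelling would be accepted there.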
diff --git a/src/fast_agent/llm/internal/passthrough.py b/src/fast_agent/llm/internal/passthrough.py index 52fb5b1d9..612c492e8 100644 --- a/src/fast_agent/llm/internal/passthrough.py +++ b/src/fast_agent/llm/internal/passthrough.py @@ -1,5 +1,5 @@ import json # Import at the module level -from typing import Any, Dict, List, Optional +from typing import Any from mcp import CallToolRequest, Tool from mcp.types import CallToolRequestParams, PromptMessage @@ -41,7 +41,7 @@ def __init__( async def initialize(self) -> None: pass - def _parse_tool_command(self, command: str) -> tuple[str, Optional[dict]]: + def _parse_tool_command(self, command: str) -> tuple[str, dict | None]: """ Parse a tool command string into tool name and arguments. @@ -72,7 +72,7 @@ def _parse_tool_command(self, command: str) -> tuple[str, Optional[dict]]: async def _apply_prompt_provider_specific( self, - multipart_messages: List["PromptMessageExtended"], + multipart_messages: list["PromptMessageExtended"], request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, @@ -85,7 +85,7 @@ async def _apply_prompt_provider_specific( if last_message.role == "assistant": return last_message - tool_calls: Dict[str, CallToolRequest] = {} + tool_calls: dict[str, CallToolRequest] = {} stop_reason: LlmStopReason = LlmStopReason.END_TURN if self.is_tool_call(last_message): tool_name, arguments = self._parse_tool_command(last_message.first_text()) @@ -143,8 +143,8 @@ async def _apply_prompt_provider_specific( return result def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[Any]: + self, messages: list[PromptMessageExtended] + ) -> list[Any]: """ Convert PromptMessageExtended list to provider format. For PassthroughLLM, we don't actually make API calls, so this just returns empty list. 
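The passthrough hunk above collects tool calls into dict[str, CallToolRequest] keyed by a correlation id. A rough sketch of one entry, built from the mcp types imported in that file; the key scheme and the example tool name and arguments are assumptions:

    from mcp import CallToolRequest
    from mcp.types import CallToolRequestParams

    tool_calls: dict[str, CallToolRequest] = {
        "call_0": CallToolRequest(  # "call_0" is a hypothetical correlation id
            method="tools/call",
            params=CallToolRequestParams(name="fetch", arguments={"url": "https://example.com"}),
        ),
    }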
diff --git a/src/fast_agent/llm/internal/playback.py b/src/fast_agent/llm/internal/playback.py index 2d5a05d1b..356bb0b38 100644 --- a/src/fast_agent/llm/internal/playback.py +++ b/src/fast_agent/llm/internal/playback.py @@ -1,4 +1,4 @@ -from typing import Any, List, Type, Union +from typing import Any, Type, Union from mcp import Tool from mcp.types import PromptMessage @@ -31,7 +31,7 @@ class PlaybackLLM(PassthroughLLM): def __init__(self, name: str = "Playback", **kwargs: dict[str, Any]) -> None: super().__init__(name=name, provider=Provider.FAST_AGENT, **kwargs) - self._messages: List[PromptMessageExtended] = [] + self._messages: list[PromptMessageExtended] = [] self._current_index = -1 self._overage = -1 @@ -60,10 +60,10 @@ async def generate( str, PromptMessage, PromptMessageExtended, - List[Union[str, PromptMessage, PromptMessageExtended]], + list[Union[str, PromptMessage, PromptMessageExtended]], ], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Handle playback of messages in two modes: @@ -112,7 +112,7 @@ async def generate( async def structured( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, ) -> tuple[ModelT | None, PromptMessageExtended]: diff --git a/src/fast_agent/llm/internal/slow.py b/src/fast_agent/llm/internal/slow.py index 20d5e3786..bba08322c 100644 --- a/src/fast_agent/llm/internal/slow.py +++ b/src/fast_agent/llm/internal/slow.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, List +from typing import Any from mcp import Tool @@ -26,7 +26,7 @@ def __init__( async def _apply_prompt_provider_specific( self, - multipart_messages: List["PromptMessageExtended"], + multipart_messages: list["PromptMessageExtended"], request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, diff --git a/src/fast_agent/llm/memory.py b/src/fast_agent/llm/memory.py index 88367ea71..72cc30542 100644 --- a/src/fast_agent/llm/memory.py +++ b/src/fast_agent/llm/memory.py @@ -1,4 +1,4 @@ -from typing import Generic, List, Optional, Protocol, TypeVar +from typing import Generic, Protocol, TypeVar # Define our own type variable for implementation use MessageParamT = TypeVar("MessageParamT") @@ -20,17 +20,17 @@ class Memory(Protocol, Generic[MessageParamT]): def __init__(self) -> None: ... - def extend(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: ... + def extend(self, messages: list[MessageParamT], is_prompt: bool = False) -> None: ... - def set(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: ... + def set(self, messages: list[MessageParamT], is_prompt: bool = False) -> None: ... def append(self, message: MessageParamT, is_prompt: bool = False) -> None: ... - def get(self, include_completion_history: bool = True) -> List[MessageParamT]: ... + def get(self, include_completion_history: bool = True) -> list[MessageParamT]: ... def clear(self, clear_prompts: bool = False) -> None: ... - def pop(self, *, from_prompts: bool = False) -> Optional[MessageParamT]: ... + def pop(self, *, from_prompts: bool = False) -> MessageParamT | None: ... 
class SimpleMemory(Memory, Generic[MessageParamT]): @@ -42,15 +42,15 @@ class SimpleMemory(Memory, Generic[MessageParamT]): """ def __init__(self) -> None: - self.history: List[MessageParamT] = [] - self.prompt_messages: List[MessageParamT] = [] # Always included - self.conversation_cache_positions: List[ + self.history: list[MessageParamT] = [] + self.prompt_messages: list[MessageParamT] = [] # Always included + self.conversation_cache_positions: list[ int ] = [] # Track active conversation cache positions self.cache_walk_distance: int = 6 # Messages between cache blocks self.max_conversation_cache_blocks: int = 2 # Maximum conversation cache blocks - def extend(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: + def extend(self, messages: list[MessageParamT], is_prompt: bool = False) -> None: """ Add multiple messages to history. @@ -63,7 +63,7 @@ def extend(self, messages: List[MessageParamT], is_prompt: bool = False) -> None else: self.history.extend(messages) - def set(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: + def set(self, messages: list[MessageParamT], is_prompt: bool = False) -> None: """ Replace messages in history. @@ -89,7 +89,7 @@ def append(self, message: MessageParamT, is_prompt: bool = False) -> None: else: self.history.append(message) - def get(self, include_completion_history: bool = True) -> List[MessageParamT]: + def get(self, include_completion_history: bool = True) -> list[MessageParamT]: """ Get all messages in memory. @@ -127,7 +127,7 @@ def clear(self, clear_prompts: bool = False) -> None: if clear_prompts: self.prompt_messages = [] - def pop(self, *, from_prompts: bool = False) -> Optional[MessageParamT]: + def pop(self, *, from_prompts: bool = False) -> MessageParamT | None: """ Remove and return the most recent message from history or prompt messages. @@ -168,7 +168,7 @@ def should_apply_conversation_cache(self) -> bool: self.conversation_cache_positions ) - def _calculate_cache_positions(self, total_conversation_messages: int) -> List[int]: + def _calculate_cache_positions(self, total_conversation_messages: int) -> list[int]: """ Calculate where cache blocks should be placed using walking algorithm. @@ -227,7 +227,7 @@ def apply_conversation_cache_updates(self, updates: dict) -> None: self.conversation_cache_positions = updates["active"].copy() def remove_cache_control_from_messages( - self, messages: List[MessageParamT], positions: List[int] + self, messages: list[MessageParamT], positions: list[int] ) -> None: """ Remove cache control from specified message positions. @@ -247,7 +247,7 @@ def remove_cache_control_from_messages( del content_block["cache_control"] def add_cache_control_to_messages( - self, messages: List[MessageParamT], positions: List[int] + self, messages: list[MessageParamT], positions: list[int] ) -> int: """ Add cache control to specified message positions. diff --git a/src/fast_agent/llm/model_database.py b/src/fast_agent/llm/model_database.py index 376483061..27ebb84d8 100644 --- a/src/fast_agent/llm/model_database.py +++ b/src/fast_agent/llm/model_database.py @@ -5,7 +5,7 @@ context windows, max output tokens, and supported tokenization types. 
""" -from typing import Dict, List, Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -19,7 +19,7 @@ class ModelParameters(BaseModel): max_output_tokens: int """Maximum output tokens the model can generate""" - tokenizes: List[str] + tokenizes: list[str] """List of supported content types for tokenization""" json_mode: None | str = "schema" @@ -207,7 +207,7 @@ class ModelDatabase: # Model configuration database # KEEP ALL LOWER CASE KEYS - MODELS: Dict[str, ModelParameters] = { + MODELS: dict[str, ModelParameters] = { # internal models "passthrough": FAST_AGENT_STANDARD, "silent": FAST_AGENT_STANDARD, @@ -307,24 +307,24 @@ class ModelDatabase: } @classmethod - def get_model_params(cls, model: str) -> Optional[ModelParameters]: + def get_model_params(cls, model: str) -> ModelParameters | None: """Get model parameters for a given model name""" return cls.MODELS.get(model.lower()) @classmethod - def get_context_window(cls, model: str) -> Optional[int]: + def get_context_window(cls, model: str) -> int | None: """Get context window size for a model""" params = cls.get_model_params(model) return params.context_window if params else None @classmethod - def get_max_output_tokens(cls, model: str) -> Optional[int]: + def get_max_output_tokens(cls, model: str) -> int | None: """Get maximum output tokens for a model""" params = cls.get_model_params(model) return params.max_output_tokens if params else None @classmethod - def get_tokenizes(cls, model: str) -> Optional[List[str]]: + def get_tokenizes(cls, model: str) -> list[str] | None: """Get supported tokenization types for a model""" params = cls.get_model_params(model) return params.tokenizes if params else None @@ -357,7 +357,7 @@ def supports_mime(cls, model: str, mime_type: str) -> bool: return normalized.lower() in normalized_supported @classmethod - def supports_any_mime(cls, model: str, mime_types: List[str]) -> bool: + def supports_any_mime(cls, model: str, mime_types: list[str]) -> bool: """Return True if the model supports any of the provided MIME types.""" return any(cls.supports_mime(model, m) for m in mime_types) @@ -394,6 +394,6 @@ def get_default_max_tokens(cls, model: str) -> int: return 2048 # Fallback for unknown models @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: """List all available model names""" return list(cls.MODELS.keys()) diff --git a/src/fast_agent/llm/model_factory.py b/src/fast_agent/llm/model_factory.py index 9058e6037..22592addd 100644 --- a/src/fast_agent/llm/model_factory.py +++ b/src/fast_agent/llm/model_factory.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Optional, Type, Union +from typing import Type, Union from pydantic import BaseModel @@ -155,11 +155,11 @@ def _bedrock_pattern_matches(model_name: str) -> bool: return False # Mapping of providers to their LLM classes - PROVIDER_CLASSES: Dict[Provider, LLMClass] = {} + PROVIDER_CLASSES: dict[Provider, LLMClass] = {} # Mapping of special model names to their specific LLM classes # This overrides the provider-based class selection - MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = { + MODEL_SPECIFIC_CLASSES: dict[str, LLMClass] = { "playback": PlaybackLLM, "silent": SilentLLM, "slow": SlowLLM, @@ -264,7 +264,7 @@ def create_factory(cls, model_string: str) -> LLMFactoryProtocol: llm_class = cls.PROVIDER_CLASSES[config.provider] def factory( - agent: AgentProtocol, request_params: Optional[RequestParams] = None, **kwargs + agent: AgentProtocol, request_params: RequestParams | 
None = None, **kwargs ) -> FastAgentLLMProtocol: base_params = RequestParams() base_params.model = config.model_name diff --git a/src/fast_agent/llm/model_info.py b/src/fast_agent/llm/model_info.py index 2390647dc..65984f698 100644 --- a/src/fast_agent/llm/model_info.py +++ b/src/fast_agent/llm/model_info.py @@ -8,7 +8,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.model_factory import ModelFactory @@ -25,11 +25,11 @@ class ModelInfo: name: str provider: Provider - context_window: Optional[int] - max_output_tokens: Optional[int] - tokenizes: List[str] - json_mode: Optional[str] - reasoning: Optional[str] + context_window: int | None + max_output_tokens: int | None + tokenizes: list[str] + json_mode: str | None + reasoning: str | None @property def supports_text(self) -> bool: @@ -62,7 +62,7 @@ def tdv_flags(self) -> tuple[bool, bool, bool]: return (self.supports_text, self.supports_document, self.supports_vision) @classmethod - def from_llm(cls, llm: "FastAgentLLMProtocol") -> Optional["ModelInfo"]: + def from_llm(cls, llm: "FastAgentLLMProtocol") -> "ModelInfo" | None: name = getattr(llm, "model_name", None) provider = getattr(llm, "provider", None) if not name or not provider: @@ -70,7 +70,7 @@ def from_llm(cls, llm: "FastAgentLLMProtocol") -> Optional["ModelInfo"]: return cls.from_name(name, provider) @classmethod - def from_name(cls, name: str, provider: Provider | None = None) -> Optional["ModelInfo"]: + def from_name(cls, name: str, provider: Provider | None = None) -> "ModelInfo" | None: canonical_name = ModelFactory.MODEL_ALIASES.get(name, name) params = ModelDatabase.get_model_params(canonical_name) if not params: diff --git a/src/fast_agent/llm/prompt_utils.py b/src/fast_agent/llm/prompt_utils.py index a07d4bf6c..73881f1e8 100644 --- a/src/fast_agent/llm/prompt_utils.py +++ b/src/fast_agent/llm/prompt_utils.py @@ -2,13 +2,13 @@ XML formatting utilities for consistent prompt engineering across components. """ -from typing import Dict, List, Optional, TypedDict +from typing import TypedDict def format_xml_tag( tag_name: str, - content: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, + content: str | None = None, + attributes: dict[str, str] | None = None, ) -> str: """ Format an XML tag with optional content and attributes. @@ -37,8 +37,8 @@ def format_xml_tag( def format_fastagent_tag( tag_type: str, - content: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, + content: str | None = None, + attributes: dict[str, str] | None = None, ) -> str: """ Format a fastagent-namespaced XML tag with consistent formatting. @@ -56,8 +56,8 @@ def format_fastagent_tag( def format_server_info( server_name: str, - description: Optional[str] = None, - tools: Optional[List[Dict[str, str]]] = None, + description: str | None = None, + tools: list[dict[str, str]] | None = None, ) -> str: """ Format server information consistently across router and orchestrator modules. 
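For orientation, a small usage sketch of the prompt_utils helpers changed above. The call signatures come from the hunks; the rendered strings in the comments are assumptions about the implementation:

    from fast_agent.llm.prompt_utils import format_fastagent_tag, format_xml_tag

    # Presumably renders something like: <server name="fetch">Fetches web pages</server>
    tag = format_xml_tag("server", content="Fetches web pages", attributes={"name": "fetch"})

    # Same idea, with the fastagent namespace applied to the tag name.
    agent_tag = format_fastagent_tag("instruction", content="You are a router.")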
@@ -103,13 +103,13 @@ def format_server_info( class ServerInfo(TypedDict, total=False): name: str description: str - tools: List[Dict[str, str]] + tools: list[dict[str, str]] def format_agent_info( agent_name: str, - description: Optional[str] = None, - servers: Optional[List[ServerInfo]] = None, + description: str | None = None, + servers: list[ServerInfo] | None = None, ) -> str: """ Format agent information consistently across router and orchestrator modules. diff --git a/src/fast_agent/llm/provider/anthropic/cache_planner.py b/src/fast_agent/llm/provider/anthropic/cache_planner.py index fcc9135d2..254dbcc73 100644 --- a/src/fast_agent/llm/provider/anthropic/cache_planner.py +++ b/src/fast_agent/llm/provider/anthropic/cache_planner.py @@ -1,4 +1,3 @@ -from typing import List from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -16,15 +15,15 @@ def __init__( self.max_conversation_blocks = max_conversation_blocks self.max_total_blocks = max_total_blocks - def _template_prefix_count(self, messages: List[PromptMessageExtended]) -> int: + def _template_prefix_count(self, messages: list[PromptMessageExtended]) -> int: return sum(msg.is_template for msg in messages) def plan_indices( self, - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], cache_mode: str, system_cache_blocks: int = 0, - ) -> List[int]: + ) -> list[int]: """Return message indices that should receive cache_control.""" if cache_mode == "off" or not messages: @@ -35,13 +34,13 @@ def plan_indices( return [] template_prefix = self._template_prefix_count(messages) - template_indices: List[int] = [] + template_indices: list[int] = [] if cache_mode in ("prompt", "auto") and template_prefix: template_indices = list(range(min(template_prefix, budget))) budget -= len(template_indices) - conversation_indices: List[int] = [] + conversation_indices: list[int] = [] if cache_mode == "auto" and budget > 0: conv_count = max(0, len(messages) - template_prefix) if conv_count >= self.walk_distance: diff --git a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py index ea0008739..ab3f6e407 100644 --- a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Tuple, Type, Union, cast +from typing import Any, Type, Union, cast from anthropic import APIError, AsyncAnthropic, AuthenticationError from anthropic.lib.streaming import AsyncMessageStream @@ -46,7 +46,7 @@ STRUCTURED_OUTPUT_TOOL_NAME = "return_structured_output" # Type alias for system field - can be string or list of text blocks with cache control -SystemParam = Union[str, List[TextBlockParam]] +SystemParam = Union[str, list[TextBlockParam]] logger = get_logger(__name__) @@ -96,8 +96,8 @@ def _get_cache_mode(self) -> str: return cache_mode async def _prepare_tools( - self, structured_model: Type[ModelT] | None = None, tools: List[Tool] | None = None - ) -> List[ToolParam]: + self, structured_model: Type[ModelT] | None = None, tools: list[Tool] | None = None + ) -> list[ToolParam]: """Prepare tools based on whether we're in structured output mode.""" if structured_model: return [ @@ -159,7 +159,7 @@ def _apply_cache_control_to_message(message: MessageParam) -> bool: return False - def _is_structured_output_request(self, tool_uses: List[Any]) -> bool: + def _is_structured_output_request(self, tool_uses: list[Any]) -> bool: """ Check if the tool uses contain a 
structured output request. @@ -171,7 +171,7 @@ def _is_structured_output_request(self, tool_uses: List[Any]) -> bool: """ return any(tool.name == STRUCTURED_OUTPUT_TOOL_NAME for tool in tool_uses) - def _build_tool_calls_dict(self, tool_uses: List[ToolUseBlock]) -> dict[str, CallToolRequest]: + def _build_tool_calls_dict(self, tool_uses: list[ToolUseBlock]) -> dict[str, CallToolRequest]: """ Convert Anthropic tool use blocks into our CallToolRequest. @@ -197,8 +197,8 @@ async def _handle_structured_output_response( self, tool_use_block: ToolUseBlock, structured_model: Type[ModelT], - messages: List[MessageParam], - ) -> Tuple[LlmStopReason, List[ContentBlock]]: + messages: list[MessageParam], + ) -> tuple[LlmStopReason, list[ContentBlock]]: """ Handle a structured output tool response from Anthropic. @@ -451,18 +451,18 @@ def _build_request_messages( self, params: RequestParams, message_param: MessageParam, - pre_messages: List[MessageParam] | None = None, - history: List[PromptMessageExtended] | None = None, - ) -> List[MessageParam]: + pre_messages: list[MessageParam] | None = None, + history: list[PromptMessageExtended] | None = None, + ) -> list[MessageParam]: """ Build the list of Anthropic message parameters for the next request. Ensures that the current user message is only included once when history is enabled, which prevents duplicate tool_result blocks from being sent. """ - messages: List[MessageParam] = list(pre_messages) if pre_messages else [] + messages: list[MessageParam] = list(pre_messages) if pre_messages else [] - history_messages: List[MessageParam] = [] + history_messages: list[MessageParam] = [] if params.use_history and history: history_messages = self._convert_to_provider_format(history) messages.extend(history_messages) @@ -478,9 +478,9 @@ async def _anthropic_completion( message_param, request_params: RequestParams | None = None, structured_model: Type[ModelT] | None = None, - tools: List[Tool] | None = None, - pre_messages: List[MessageParam] | None = None, - history: List[PromptMessageExtended] | None = None, + tools: list[Tool] | None = None, + pre_messages: list[MessageParam] | None = None, + history: list[PromptMessageExtended] | None = None, current_extended: PromptMessageExtended | None = None, ) -> PromptMessageExtended: """ @@ -509,7 +509,7 @@ async def _anthropic_completion( available_tools = await self._prepare_tools(structured_model, tools) - response_content_blocks: List[ContentBlock] = [] + response_content_blocks: list[ContentBlock] = [] tool_calls: dict[str, CallToolRequest] | None = None model = self.default_request_params.model or DEFAULT_ANTHROPIC_MODEL @@ -544,7 +544,7 @@ async def _anthropic_completion( planner = AnthropicCachePlanner( self.CONVERSATION_CACHE_WALK_DISTANCE, self.MAX_CONVERSATION_CACHE_BLOCKS ) - plan_messages: List[PromptMessageExtended] = [] + plan_messages: list[PromptMessageExtended] = [] include_current = not params.use_history or not history if params.use_history and history: plan_messages.extend(history) @@ -641,9 +641,9 @@ async def _anthropic_completion( async def _apply_prompt_provider_specific( self, - multipart_messages: List["PromptMessageExtended"], + multipart_messages: list["PromptMessageExtended"], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: """ @@ -673,10 +673,10 @@ async def _apply_prompt_provider_specific( async def _apply_prompt_provider_specific_structured( self, - 
multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: # noqa: F821 + ) -> tuple[ModelT | None, PromptMessageExtended]: # noqa: F821 """ Provider-specific structured output implementation. Note: Message history is managed by base class and converted via @@ -718,8 +718,8 @@ async def _apply_prompt_provider_specific_structured( return None, last_message def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[MessageParam]: + self, messages: list[PromptMessageExtended] + ) -> list[MessageParam]: """ Convert PromptMessageExtended list to Anthropic MessageParam format. This is called fresh on every API call from _convert_to_provider_format(). diff --git a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py index f9b880e81..cb63ee4e6 100644 --- a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py @@ -1,5 +1,5 @@ import re -from typing import List, Sequence, Union +from typing import Sequence, Union from anthropic.types import ( Base64ImageSourceParam, @@ -162,7 +162,7 @@ def convert_prompt_message_to_anthropic(message: PromptMessage) -> MessageParam: def _convert_content_items( content_items: Sequence[ContentBlock], document_mode: bool = True, - ) -> List[ContentBlockParam]: + ) -> list[ContentBlockParam]: """ Convert a list of content items to Anthropic content blocks. @@ -173,7 +173,7 @@ def _convert_content_items( Returns: List of Anthropic content blocks """ - anthropic_blocks: List[ContentBlockParam] = [] + anthropic_blocks: list[ContentBlockParam] = [] for content_item in content_items: if is_text_content(content_item): @@ -392,7 +392,7 @@ def _create_fallback_text( @staticmethod def create_tool_results_message( - tool_results: List[tuple[str, CallToolResult]], + tool_results: list[tuple[str, CallToolResult]], ) -> MessageParam: """ Create a user message containing tool results. diff --git a/src/fast_agent/llm/provider/bedrock/bedrock_utils.py b/src/fast_agent/llm/provider/bedrock/bedrock_utils.py index 661e9f060..806e4351f 100644 --- a/src/fast_agent/llm/provider/bedrock/bedrock_utils.py +++ b/src/fast_agent/llm/provider/bedrock/bedrock_utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Collection, Dict, List, Literal, Optional, Set, TypedDict, cast +from typing import Collection, Literal, TypedDict, cast # Lightweight, runtime-only loader for AWS Bedrock models. 
# - Fetches once per process via boto3 (region from session; env override supported) @@ -22,18 +22,18 @@ class ModelSummary(TypedDict, total=False): modelId: str modelName: str providerName: str - inputModalities: List[Modality] - outputModalities: List[Modality] + inputModalities: list[Modality] + outputModalities: list[Modality] responseStreamingSupported: bool - customizationsSupported: List[str] - inferenceTypesSupported: List[InferenceType] - modelLifecycle: Dict[str, Lifecycle] + customizationsSupported: list[str] + inferenceTypesSupported: list[InferenceType] + modelLifecycle: dict[str, Lifecycle] -_MODELS_CACHE_BY_REGION: Dict[str, Dict[str, ModelSummary]] = {} +_MODELS_CACHE_BY_REGION: dict[str, dict[str, ModelSummary]] = {} -def _resolve_region(region: Optional[str]) -> str: +def _resolve_region(region: str | None) -> str: if region: return region import os @@ -57,7 +57,7 @@ def _strip_prefix(model_id: str, prefix: str) -> str: return model_id[len(prefix) :] if prefix and model_id.startswith(prefix) else model_id -def _ensure_loaded(region: Optional[str] = None) -> Dict[str, ModelSummary]: +def _ensure_loaded(region: str | None = None) -> dict[str, ModelSummary]: resolved_region = _resolve_region(region) cache = _MODELS_CACHE_BY_REGION.get(resolved_region) if cache is not None: @@ -69,7 +69,7 @@ def _ensure_loaded(region: Optional[str] = None) -> Dict[str, ModelSummary]: try: client = boto3.client("bedrock", region_name=resolved_region) resp = client.list_foundation_models() - summaries: List[ModelSummary] = resp.get("modelSummaries", []) # type: ignore[assignment] + summaries: list[ModelSummary] = resp.get("modelSummaries", []) # type: ignore[assignment] except Exception as exc: # keep error simple and actionable raise RuntimeError( f"Failed to list Bedrock foundation models in region '{resolved_region}'. " @@ -82,27 +82,27 @@ def _ensure_loaded(region: Optional[str] = None) -> Dict[str, ModelSummary]: return cache -def refresh_bedrock_models(region: Optional[str] = None) -> None: +def refresh_bedrock_models(region: str | None = None) -> None: resolved_region = _resolve_region(region) # drop and reload on next access _MODELS_CACHE_BY_REGION.pop(resolved_region, None) _ensure_loaded(resolved_region) -def _matches_modalities(model_modalities: List[Modality], requested: Collection[Modality]) -> bool: +def _matches_modalities(model_modalities: list[Modality], requested: Collection[Modality]) -> bool: # include if all requested are present in the model's modalities return set(requested).issubset(set(model_modalities)) def all_model_summaries( - input_modalities: Optional[Collection[Modality]] = None, - output_modalities: Optional[Collection[Modality]] = None, + input_modalities: Collection[Modality] | None = None, + output_modalities: Collection[Modality] | None = None, include_legacy: bool = False, - providers: Optional[Collection[str]] = None, - inference_types: Optional[Collection[InferenceType]] = None, + providers: Collection[str] | None = None, + inference_types: Collection[InferenceType] | None = None, direct_invocation_only: bool = True, - region: Optional[str] = None, -) -> List[ModelSummary]: + region: str | None = None, +) -> list[ModelSummary]: """Return filtered Bedrock model summaries. 
Defaults: input_modalities={"TEXT"}, output_modalities={"TEXT"}, include_legacy=False, @@ -110,16 +110,16 @@ def all_model_summaries( """ cache = _ensure_loaded(region) - results: List[ModelSummary] = [] + results: list[ModelSummary] = [] - effective_output: Set[Modality] = ( + effective_output: set[Modality] = ( set(output_modalities) if output_modalities is not None else {cast("Modality", "TEXT")} ) - effective_input: Optional[Set[Modality]] = ( + effective_input: set[Modality] | None = ( set(input_modalities) if input_modalities is not None else {cast("Modality", "TEXT")} ) - provider_filter: Optional[Set[str]] = set(providers) if providers is not None else None - effective_inference: Set[InferenceType] = ( + provider_filter: set[str] | None = set(providers) if providers is not None else None + effective_inference: set[InferenceType] = ( set(inference_types) if inference_types is not None else {cast("InferenceType", "ON_DEMAND")} @@ -140,8 +140,8 @@ def all_model_summaries( continue # modalities - model_inputs: List[Modality] = summary.get("inputModalities", []) # type: ignore[assignment] - model_outputs: List[Modality] = summary.get("outputModalities", []) # type: ignore[assignment] + model_inputs: list[Modality] = summary.get("inputModalities", []) # type: ignore[assignment] + model_outputs: list[Modality] = summary.get("outputModalities", []) # type: ignore[assignment] if effective_input is not None and not _matches_modalities(model_inputs, effective_input): continue @@ -149,7 +149,7 @@ def all_model_summaries( continue # inference types - model_inference: List[InferenceType] = summary.get("inferenceTypesSupported", []) # type: ignore[assignment] + model_inference: list[InferenceType] = summary.get("inferenceTypesSupported", []) # type: ignore[assignment] if effective_inference and not set(effective_inference).issubset(set(model_inference)): continue @@ -159,15 +159,15 @@ def all_model_summaries( def all_bedrock_models( - input_modalities: Optional[Collection[Modality]] = None, - output_modalities: Optional[Collection[Modality]] = None, + input_modalities: Collection[Modality] | None = None, + output_modalities: Collection[Modality] | None = None, include_legacy: bool = False, - providers: Optional[Collection[str]] = None, + providers: Collection[str] | None = None, prefix: str = "bedrock.", - inference_types: Optional[Collection[InferenceType]] = None, + inference_types: Collection[InferenceType] | None = None, direct_invocation_only: bool = True, - region: Optional[str] = None, -) -> List[str]: + region: str | None = None, +) -> list[str]: """Return model IDs (optionally prefixed) filtered by the given criteria. 
Defaults: output_modalities={"TEXT"}, exclude LEGACY, @@ -183,7 +183,7 @@ def all_bedrock_models( direct_invocation_only=direct_invocation_only, region=region, ) - ids: List[str] = [] + ids: list[str] = [] for s in summaries: mid = s.get("modelId") if mid: @@ -193,14 +193,14 @@ def all_bedrock_models( return ids -def get_model_metadata(model_id: str, region: Optional[str] = None) -> Optional[ModelSummary]: +def get_model_metadata(model_id: str, region: str | None = None) -> ModelSummary | None: cache = _ensure_loaded(region) # Accept either prefixed or plain model IDs plain_id = _strip_prefix(model_id, "bedrock.") return cache.get(plain_id) -def list_providers(region: Optional[str] = None) -> List[str]: +def list_providers(region: str | None = None) -> list[str]: cache = _ensure_loaded(region) providers = {s.get("providerName") for s in cache.values() if s.get("providerName")} return sorted(providers) # type: ignore[arg-type] diff --git a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py index cb1795472..a19cff847 100644 --- a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Type, Union from mcp import Tool from mcp.types import ( @@ -68,8 +68,8 @@ class ReasoningEffort(Enum): } # Bedrock message format types -BedrockMessage = Dict[str, Any] # Bedrock message format -BedrockMessageParam = Dict[str, Any] # Bedrock message parameter format +BedrockMessage = dict[str, Any] # Bedrock message format +BedrockMessageParam = dict[str, Any] # Bedrock message parameter format class ToolSchemaType(Enum): @@ -132,7 +132,7 @@ class BedrockLLM(FastAgentLLM[BedrockMessageParam, BedrockMessage]): """ # Class-level capabilities cache shared across all instances - capabilities: Dict[str, ModelCapabilities] = {} + capabilities: dict[str, ModelCapabilities] = {} @classmethod def debug_cache(cls) -> None: @@ -281,8 +281,8 @@ def _get_bedrock_runtime_client(self): return self._bedrock_runtime_client def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[BedrockMessageParam]: + self, messages: list[PromptMessageExtended] + ) -> list[BedrockMessageParam]: """ Convert PromptMessageExtended list to Bedrock BedrockMessageParam format. This is called fresh on every API call from _convert_to_provider_format(). @@ -293,7 +293,7 @@ def _convert_extended_messages_to_provider( Returns: List of Bedrock BedrockMessageParam objects """ - converted: List[BedrockMessageParam] = [] + converted: list[BedrockMessageParam] = [] for msg in messages: bedrock_msg = BedrockConverter.convert_to_bedrock(msg) converted.append(bedrock_msg) @@ -301,7 +301,7 @@ def _convert_extended_messages_to_provider( def _build_tool_name_mapping( self, tools: "ListToolsResult", name_policy: ToolNamePolicy - ) -> Dict[str, str]: + ) -> dict[str, str]: """Build tool name mapping based on schema type and name policy. Returns dict mapping from converted_name -> original_name for tool execution. 
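Ahead of the Nova tool-conversion hunk that follows, a stripped-down sketch of the schema cleaning it describes. Assumption: only type, properties, and required survive, and each property keeps just its type and description, mirroring the comments in that hunk:

    from typing import Any

    def nova_clean(input_schema: dict[str, Any]) -> dict[str, Any]:
        """Illustrative only: reduce a JSON schema to the fields Nova accepts."""
        nova_schema: dict[str, Any] = {"type": "object", "properties": {}}
        for prop_name, prop_def in (input_schema.get("properties") or {}).items():
            clean_prop: dict[str, Any] = {}
            if isinstance(prop_def, dict):
                if "type" in prop_def:
                    clean_prop["type"] = prop_def["type"]
                if "description" in prop_def:
                    clean_prop["description"] = prop_def["description"]
            nova_schema["properties"][prop_name] = clean_prop
        if isinstance(input_schema.get("required"), list):
            nova_schema["required"] = input_schema["required"]
        return nova_schema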
@@ -324,8 +324,8 @@ def _build_tool_name_mapping( return mapping def _convert_tools_nova_format( - self, tools: "ListToolsResult", tool_name_mapping: Dict[str, str] - ) -> List[Dict[str, Any]]: + self, tools: "ListToolsResult", tool_name_mapping: dict[str, str] + ) -> list[dict[str, Any]]: """Convert MCP tools to Nova-specific toolSpec format. Note: Nova models have VERY strict JSON schema requirements: @@ -347,14 +347,14 @@ def _convert_tools_nova_format( # Create Nova-compliant schema with ONLY the three allowed fields # Always include type and properties (even if empty) - nova_schema: Dict[str, Any] = {"type": "object", "properties": {}} + nova_schema: dict[str, Any] = {"type": "object", "properties": {}} # Properties - clean them strictly - properties: Dict[str, Any] = {} + properties: dict[str, Any] = {} if "properties" in input_schema and isinstance(input_schema["properties"], dict): for prop_name, prop_def in input_schema["properties"].items(): # Only include type and description for each property - clean_prop: Dict[str, Any] = {} + clean_prop: dict[str, Any] = {} if isinstance(prop_def, dict): # Only include type (required) and description (optional) @@ -408,7 +408,7 @@ def _convert_tools_nova_format( return bedrock_tools def _convert_tools_system_prompt_format( - self, tools: "ListToolsResult", tool_name_mapping: Dict[str, str] + self, tools: "ListToolsResult", tool_name_mapping: dict[str, str] ) -> str: """Convert MCP tools to system prompt format.""" if not tools.tools: @@ -460,8 +460,8 @@ def _convert_tools_system_prompt_format( return system_prompt def _convert_tools_anthropic_format( - self, tools: "ListToolsResult", tool_name_mapping: Dict[str, str] - ) -> List[Dict[str, Any]]: + self, tools: "ListToolsResult", tool_name_mapping: dict[str, str] + ) -> list[dict[str, Any]]: """Convert MCP tools to Anthropic format wrapped in Bedrock toolSpec - preserves raw schema.""" self.logger.debug( @@ -493,8 +493,8 @@ def _convert_tools_anthropic_format( return bedrock_tools def _parse_system_prompt_tool_response( - self, processed_response: Dict[str, Any], model: str - ) -> List[Dict[str, Any]]: + self, processed_response: dict[str, Any], model: str + ) -> list[dict[str, Any]]: """Parse system prompt tool response format: function calls in text.""" # Extract text content from the response text_content = "" @@ -722,8 +722,8 @@ def _parse_system_prompt_tool_response( return [] def _parse_anthropic_tool_response( - self, processed_response: Dict[str, Any] - ) -> List[Dict[str, Any]]: + self, processed_response: dict[str, Any] + ) -> list[dict[str, Any]]: """Parse Anthropic tool response format (same as native provider).""" tool_uses = [] @@ -743,8 +743,8 @@ def _parse_anthropic_tool_response( return tool_uses def _parse_tool_response( - self, processed_response: Dict[str, Any], model: str - ) -> List[Dict[str, Any]]: + self, processed_response: dict[str, Any], model: str + ) -> list[dict[str, Any]]: """Parse tool responses using cached schema, without model/family heuristics.""" caps = self.capabilities.get(model) or ModelCapabilities() schema = caps.schema @@ -762,7 +762,7 @@ def _parse_tool_response( if isinstance(c, dict) and "toolUse" in c ] if tool_uses: - parsed_tools: List[Dict[str, Any]] = [] + parsed_tools: list[dict[str, Any]] = [] for item in tool_uses: tu = item.get("toolUse", {}) if not isinstance(tu, dict): @@ -820,8 +820,8 @@ def _parse_tool_response( return [] def _build_tool_calls_dict( - self, parsed_tools: List[Dict[str, Any]] - ) -> Dict[str, CallToolRequest]: + self, 
parsed_tools: list[dict[str, Any]] + ) -> dict[str, CallToolRequest]: """ Convert parsed tools to CallToolRequest dict for external execution. @@ -950,8 +950,8 @@ def _convert_multipart_to_bedrock_message( return bedrock_msg def _convert_messages_to_bedrock( - self, messages: List[BedrockMessageParam] - ) -> List[Dict[str, Any]]: + self, messages: list[BedrockMessageParam] + ) -> list[dict[str, Any]]: """Convert message parameters to Bedrock format.""" bedrock_messages = [] for message in messages: @@ -1211,9 +1211,9 @@ async def _bedrock_completion( self, message_param: BedrockMessageParam, request_params: RequestParams | None = None, - tools: List[Tool] | None = None, - pre_messages: List[BedrockMessageParam] | None = None, - history: List[PromptMessageExtended] | None = None, + tools: list[Tool] | None = None, + pre_messages: list[BedrockMessageParam] | None = None, + history: list[PromptMessageExtended] | None = None, ) -> PromptMessageExtended: """ Process a query using Bedrock and available tools. @@ -1222,7 +1222,7 @@ async def _bedrock_completion( client = self._get_bedrock_runtime_client() try: - messages: List[BedrockMessageParam] = list(pre_messages) if pre_messages else [] + messages: list[BedrockMessageParam] = list(pre_messages) if pre_messages else [] params = self.get_request_params(request_params) except (ClientError, BotoCoreError) as e: error_msg = str(e) @@ -1265,7 +1265,7 @@ async def _bedrock_completion( tool_list = ListToolsResult(tools=tools) - response_content_blocks: List[ContentBlock] = [] + response_content_blocks: list[ContentBlock] = [] model = self.default_request_params.model # Single API call - no tool execution loop @@ -1324,7 +1324,7 @@ async def _bedrock_completion( converse_args = {"modelId": model, "messages": [dict(m) for m in bedrock_messages]} # Build tools representation for this schema - tools_payload: Union[List[Dict[str, Any]], str, None] = None + tools_payload: Union[list[dict[str, Any]], str, None] = None # Get tool name policy (needed even when no tools for cache logic) name_policy = ( self.capabilities.get(model) or ModelCapabilities() @@ -1450,7 +1450,7 @@ async def _bedrock_completion( converse_args["toolConfig"] = {"tools": tools_payload} # Inference configuration and overrides - inference_config: Dict[str, Any] = {} + inference_config: dict[str, Any] = {} if params.maxTokens is not None: inference_config["maxTokens"] = params.maxTokens if params.stopSequences: @@ -1833,7 +1833,7 @@ async def _bedrock_completion( self.capabilities[model] = caps_tmp # NEW: Handle tool calls without execution - return them for external handling - tool_calls: Dict[str, CallToolRequest] | None = None + tool_calls: dict[str, CallToolRequest] | None = None if stop_reason in ["tool_use", "tool_calls"]: parsed_tools = self._parse_tool_response(processed_response, model) if parsed_tools: @@ -1857,9 +1857,9 @@ async def _bedrock_completion( async def _apply_prompt_provider_specific( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: """ @@ -1936,7 +1936,7 @@ def get_field_type_representation(field_type: Any) -> Any: else: return "any" - def _generate_schema_dict(model_class: Type) -> Dict[str, Any]: + def _generate_schema_dict(model_class: Type) -> dict[str, Any]: """Recursively generate the schema as a dictionary.""" schema_dict = {} if 
hasattr(model_class, "model_fields"): @@ -1949,10 +1949,10 @@ def _generate_schema_dict(model_class: Type) -> Dict[str, Any]: async def _apply_prompt_provider_specific_structured( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """Apply structured output for Bedrock using prompt engineering with a simplified schema.""" # Short-circuit: if the last message is already an assistant JSON payload, # parse it directly without invoking the model. This restores pre-regression behavior @@ -2123,7 +2123,7 @@ def _clean_json_response(self, text: str) -> str: def _structured_from_multipart( self, message: PromptMessageExtended, model: Type[ModelT] - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: """Override to apply JSON cleaning before parsing.""" # Get the text from the multipart message text = message.all_text() diff --git a/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py b/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py index f567bd2c7..24c7b6179 100644 --- a/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py @@ -1,9 +1,9 @@ -from typing import Any, Dict +from typing import Any from fast_agent.types import PromptMessageExtended # Bedrock message format types -BedrockMessageParam = Dict[str, Any] +BedrockMessageParam = dict[str, Any] class BedrockConverter: diff --git a/src/fast_agent/llm/provider/google/google_converter.py b/src/fast_agent/llm/provider/google/google_converter.py index 72e281efa..3f2db0435 100644 --- a/src/fast_agent/llm/provider/google/google_converter.py +++ b/src/fast_agent/llm/provider/google/google_converter.py @@ -1,5 +1,5 @@ import base64 -from typing import Any, Dict, List, Tuple +from typing import Any # Import necessary types from google.genai from google.genai import types @@ -30,7 +30,7 @@ class GoogleConverter: Converts between fast-agent and google.genai data structures. """ - def _clean_schema_for_google(self, schema: Dict[str, Any]) -> Dict[str, Any]: + def _clean_schema_for_google(self, schema: dict[str, Any]) -> dict[str, Any]: """ Recursively removes unsupported JSON schema keywords for google.genai.types.Schema. Specifically removes 'additionalProperties', '$schema', 'exclusiveMaximum', and 'exclusiveMinimum'. @@ -79,7 +79,7 @@ def _clean_schema_for_google(self, schema: Dict[str, Any]) -> Dict[str, Any]: cleaned_schema[key] = value return cleaned_schema - def _resolve_refs(self, schema: Dict[str, Any], root_schema: Dict[str, Any]) -> Dict[str, Any]: + def _resolve_refs(self, schema: dict[str, Any], root_schema: dict[str, Any]) -> dict[str, Any]: """ Resolve $ref references in a JSON schema by inlining the referenced definitions. @@ -128,15 +128,15 @@ def _resolve_refs(self, schema: Dict[str, Any], root_schema: Dict[str, Any]) -> return resolved def convert_to_google_content( - self, messages: List[PromptMessageExtended] - ) -> List[types.Content]: + self, messages: list[PromptMessageExtended] + ) -> list[types.Content]: """ Converts a list of fast-agent PromptMessageExtended to google.genai types.Content. Handles different roles and content types (text, images, etc.). 
""" - google_contents: List[types.Content] = [] + google_contents: list[types.Content] = [] for message in messages: - parts: List[types.Part] = [] + parts: list[types.Part] = [] for part_content in message.content: # renamed part to part_content to avoid conflict if is_text_content(part_content): parts.append(types.Part.from_text(text=get_text(part_content) or "")) @@ -184,11 +184,11 @@ def convert_to_google_content( google_contents.append(types.Content(role=google_role, parts=parts)) return google_contents - def convert_to_google_tools(self, tools: List[Tool]) -> List[types.Tool]: + def convert_to_google_tools(self, tools: list[Tool]) -> list[types.Tool]: """ Converts a list of fast-agent ToolDefinition to google.genai types.Tool. """ - google_tools: List[types.Tool] = [] + google_tools: list[types.Tool] = [] for tool in tools: cleaned_input_schema = self._clean_schema_for_google(tool.inputSchema) function_declaration = types.FunctionDeclaration( @@ -201,12 +201,12 @@ def convert_to_google_tools(self, tools: List[Tool]) -> List[types.Tool]: def convert_from_google_content( self, content: types.Content - ) -> List[ContentBlock | CallToolRequestParams]: + ) -> list[ContentBlock | CallToolRequestParams]: """ Converts google.genai types.Content from a model response to a list of fast-agent content types or tool call requests. """ - fast_agent_parts: List[ContentBlock | CallToolRequestParams] = [] + fast_agent_parts: list[ContentBlock | CallToolRequestParams] = [] if content is None or not hasattr(content, "parts") or content.parts is None: return [] # Google API response 'content' object is None. Cannot extract parts. @@ -238,17 +238,17 @@ def convert_from_google_function_call( ) def convert_function_results_to_google( - self, tool_results: List[Tuple[str, CallToolResult]] - ) -> List[types.Content]: + self, tool_results: list[tuple[str, CallToolResult]] + ) -> list[types.Content]: """ Converts a list of fast-agent tool results to google.genai types.Content with role 'tool'. Handles multimodal content in tool results. """ - google_tool_response_contents: List[types.Content] = [] + google_tool_response_contents: list[types.Content] = [] for tool_name, tool_result in tool_results: - current_content_parts: List[types.Part] = [] - textual_outputs: List[str] = [] - media_parts: List[types.Part] = [] + current_content_parts: list[types.Part] = [] + textual_outputs: list[str] = [] + media_parts: list[types.Part] = [] for item in tool_result.content: if is_text_content(item): @@ -294,7 +294,7 @@ def convert_function_results_to_google( ) # Add handling for other content types if needed, for now they are skipped or become unhandled resource text - function_response_payload: Dict[str, Any] = {"tool_name": tool_name} + function_response_payload: dict[str, Any] = {"tool_name": tool_name} if textual_outputs: function_response_payload["text_content"] = "\n".join(textual_outputs) @@ -325,7 +325,7 @@ def convert_request_params_to_google_config( """ Converts fast-agent RequestParams to google.genai types.GenerateContentConfig. 
""" - config_args: Dict[str, Any] = {} + config_args: dict[str, Any] = {} if request_params.temperature is not None: config_args["temperature"] = request_params.temperature if request_params.maxTokens is not None: @@ -351,8 +351,8 @@ def convert_request_params_to_google_config( return types.GenerateContentConfig(**config_args) def convert_from_google_content_list( - self, contents: List[types.Content] - ) -> List[PromptMessageExtended]: + self, contents: list[types.Content] + ) -> list[PromptMessageExtended]: """ Converts a list of google.genai types.Content to a list of fast-agent PromptMessageExtended. """ @@ -369,7 +369,7 @@ def _convert_from_google_content(self, content: types.Content) -> PromptMessageE if content.role == "model" and any(part.function_call for part in content.parts): return PromptMessageExtended(role="assistant", content=[]) - fast_agent_parts: List[ContentBlock | CallToolRequestParams] = [] + fast_agent_parts: list[ContentBlock | CallToolRequestParams] = [] for part in content.parts: if part.text: fast_agent_parts.append(TextContent(type="text", text=part.text)) diff --git a/src/fast_agent/llm/provider/google/llm_google_native.py b/src/fast_agent/llm/provider/google/llm_google_native.py index cafbc523a..756f2925f 100644 --- a/src/fast_agent/llm/provider/google/llm_google_native.py +++ b/src/fast_agent/llm/provider/google/llm_google_native.py @@ -1,6 +1,5 @@ import json import secrets -from typing import Dict, List # Import necessary types and client from google.genai from google import genai @@ -112,7 +111,7 @@ async def _stream_generate_content( self, *, model: str, - contents: List[types.Content], + contents: list[types.Content], config: types.GenerateContentConfig, client: genai.Client, ) -> types.GenerateContentResponse | None: @@ -145,8 +144,8 @@ async def _consume_google_stream( ) -> types.GenerateContentResponse | None: """Consume the async streaming iterator and aggregate the final response.""" estimated_tokens = 0 - timeline: List[tuple[str, int | None, str]] = [] - tool_streams: Dict[int, Dict[str, str]] = {} + timeline: list[tuple[str, int | None, str]] = [] + tool_streams: dict[int, dict[str, str]] = {} active_tool_index: int | None = None tool_counter = 0 usage_metadata = None @@ -281,7 +280,7 @@ async def _consume_google_stream( if not timeline and last_chunk is None: return None - final_parts: List[types.Part] = [] + final_parts: list[types.Part] = [] for entry_type, index, payload in timeline: if entry_type == "text": final_parts.append(types.Part.from_text(text=payload)) @@ -322,9 +321,9 @@ async def _consume_google_stream( async def _google_completion( self, - message: List[types.Content] | None, + message: list[types.Content] | None, request_params: RequestParams | None = None, - tools: List[McpTool] | None = None, + tools: list[McpTool] | None = None, *, response_mime_type: str | None = None, response_schema: object | None = None, @@ -333,15 +332,15 @@ async def _google_completion( Process a query using Google's generate_content API and available tools. 
""" request_params = self.get_request_params(request_params=request_params) - responses: List[ContentBlock] = [] + responses: list[ContentBlock] = [] # Caller supplies the full set of messages to send (history + turn) - conversation_history: List[types.Content] = list(message or []) + conversation_history: list[types.Content] = list(message or []) self.logger.debug(f"Google completion requested with messages: {conversation_history}") self._log_chat_progress(self.chat_turn(), model=request_params.model) - available_tools: List[types.Tool] = ( + available_tools: list[types.Tool] = ( self._converter.convert_to_google_tools(tools or []) if tools else [] ) @@ -430,7 +429,7 @@ async def _google_completion( candidate.content ) stop_reason = LlmStopReason.END_TURN - tool_calls: Dict[str, CallToolRequest] | None = None + tool_calls: dict[str, CallToolRequest] | None = None # Add model's response to the working conversation history for this turn conversation_history.append(candidate.content) @@ -475,9 +474,9 @@ async def _google_completion( async def _apply_prompt_provider_specific( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[McpTool] | None = None, + tools: list[McpTool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: """ @@ -495,13 +494,13 @@ async def _apply_prompt_provider_specific( # Build the provider-native message list for this turn from the last user message # This must handle tool results as function responses before any additional user content. - turn_messages: List[types.Content] = [] + turn_messages: list[types.Content] = [] # 1) Convert tool results (if any) to google function responses if last_message.tool_results: # Map correlation IDs back to tool names using the last assistant tool_calls # found in our high-level message history - id_to_name: Dict[str, str] = {} + id_to_name: dict[str, str] = {} for prev in reversed(multipart_messages): if prev.role == "assistant" and prev.tool_calls: for call_id, call in prev.tool_calls.items(): @@ -531,7 +530,7 @@ async def _apply_prompt_provider_specific( if not turn_messages: turn_messages.append(types.Content(role="user", parts=[types.Part.from_text("")])) - conversation_history: List[types.Content] = [] + conversation_history: list[types.Content] = [] if request_params.use_history and len(multipart_messages) > 1: conversation_history.extend( self._convert_to_provider_format(multipart_messages[:-1]) @@ -541,8 +540,8 @@ async def _apply_prompt_provider_specific( return await self._google_completion(conversation_history, request_params=request_params, tools=tools) def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[types.Content]: + self, messages: list[PromptMessageExtended] + ) -> list[types.Content]: """ Convert PromptMessageExtended list to Google types.Content format. This is called fresh on every API call from _convert_to_provider_format(). 
@@ -630,7 +629,7 @@ async def _apply_prompt_provider_specific_structured( response_schema = model if schema is None else schema # Convert the last user message to provider-native content for the current turn - turn_messages: List[types.Content] = [] + turn_messages: list[types.Content] = [] if last_message: turn_messages = self._converter.convert_to_google_content([last_message]) diff --git a/src/fast_agent/llm/provider/openai/llm_openai.py b/src/fast_agent/llm/provider/openai/llm_openai.py index 0534d70a2..20e4a6928 100644 --- a/src/fast_agent/llm/provider/openai/llm_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from mcp import Tool from mcp.types import ( @@ -677,9 +677,9 @@ async def _process_stream_manual(self, stream, model: str): async def _openai_completion( self, - message: List[OpenAIMessage] | None, + message: list[OpenAIMessage] | None, request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, ) -> PromptMessageExtended: """ Process a query using an LLM and available tools. @@ -689,11 +689,11 @@ async def _openai_completion( request_params = self.get_request_params(request_params=request_params) - response_content_blocks: List[ContentBlock] = [] + response_content_blocks: list[ContentBlock] = [] model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL # TODO -- move this in to agent context management / agent group handling - messages: List[ChatCompletionMessageParam] = [] + messages: list[ChatCompletionMessageParam] = [] system_prompt = self.instruction or request_params.systemPrompt if system_prompt: messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt)) @@ -702,7 +702,7 @@ async def _openai_completion( if message: messages.extend(message) - available_tools: List[ChatCompletionToolParam] | None = [ + available_tools: list[ChatCompletionToolParam] | None = [ { "type": "function", "function": { @@ -797,7 +797,7 @@ async def _openai_completion( messages.append(message_dict) stop_reason = LlmStopReason.END_TURN - requested_tool_calls: Dict[str, CallToolRequest] | None = None + requested_tool_calls: dict[str, CallToolRequest] | None = None if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls: requested_tool_calls = {} stop_reason = LlmStopReason.TOOL_USE @@ -883,9 +883,9 @@ async def _is_tool_stop_reason(self, finish_reason: str) -> bool: async def _apply_prompt_provider_specific( self, - multipart_messages: List["PromptMessageExtended"], + multipart_messages: list["PromptMessageExtended"], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: """ @@ -909,7 +909,7 @@ async def _apply_prompt_provider_specific( return await self._openai_completion(converted_messages, req_params, tools) def _prepare_api_request( - self, messages, tools: List[ChatCompletionToolParam] | None, request_params: RequestParams + self, messages, tools: list[ChatCompletionToolParam] | None, request_params: RequestParams ) -> dict[str, str]: # Create base arguments dictionary @@ -934,14 +934,14 @@ def _prepare_api_request( if tools: base_args["parallel_tool_calls"] = request_params.parallel_tool_calls - arguments: Dict[str, str] = self.prepare_provider_arguments( + arguments: dict[str, str] = self.prepare_provider_arguments( base_args, request_params, 
self.OPENAI_EXCLUDE_FIELDS.union(self.BASE_EXCLUDE_FIELDS) ) return arguments def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[ChatCompletionMessageParam]: + self, messages: list[PromptMessageExtended] + ) -> list[ChatCompletionMessageParam]: """ Convert PromptMessageExtended list to OpenAI ChatCompletionMessageParam format. This is called fresh on every API call from _convert_to_provider_format(). @@ -952,7 +952,7 @@ def _convert_extended_messages_to_provider( Returns: List of OpenAI ChatCompletionMessageParam objects """ - converted: List[ChatCompletionMessageParam] = [] + converted: list[ChatCompletionMessageParam] = [] for msg in messages: # convert_to_openai returns a list of messages @@ -960,7 +960,7 @@ def _convert_extended_messages_to_provider( return converted - def adjust_schema(self, inputSchema: Dict) -> Dict: + def adjust_schema(self, inputSchema: dict) -> dict: # return inputSchema if self.provider not in [Provider.OPENAI, Provider.AZURE]: return inputSchema diff --git a/src/fast_agent/llm/provider/openai/llm_openai_compatible.py b/src/fast_agent/llm/provider/openai/llm_openai_compatible.py index ce1b0c9bd..ada76f5a1 100644 --- a/src/fast_agent/llm/provider/openai/llm_openai_compatible.py +++ b/src/fast_agent/llm/provider/openai/llm_openai_compatible.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Type +from typing import Type from fast_agent.interfaces import ModelT from fast_agent.llm.model_database import ModelDatabase @@ -22,10 +22,10 @@ class OpenAICompatibleLLM(OpenAILLM): async def _apply_prompt_provider_specific_structured( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], model: Type[ModelT], request_params: RequestParams | None = None, - ) -> Tuple[ModelT | None, PromptMessageExtended]: + ) -> tuple[ModelT | None, PromptMessageExtended]: if not self._supports_structured_prompt(): return await super()._apply_prompt_provider_specific_structured( multipart_messages, model, request_params diff --git a/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py b/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py index 8bc82c145..2bc37122c 100644 --- a/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_tensorzero_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any from openai.types.chat import ChatCompletionMessageParam, ChatCompletionSystemMessageParam @@ -64,10 +64,10 @@ def _base_url(self) -> str: def _prepare_api_request( self, - messages: List[ChatCompletionMessageParam], - tools: Optional[List[Any]], + messages: list[ChatCompletionMessageParam], + tools: list[Any] | None, request_params: RequestParams, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Prepares the API request for the TensorZero OpenAI-compatible endpoint. 
This method injects system template variables and other TensorZero-specific diff --git a/src/fast_agent/llm/provider/openai/multipart_converter_openai.py b/src/fast_agent/llm/provider/openai/multipart_converter_openai.py index 6fc3e5813..9913b9853 100644 --- a/src/fast_agent/llm/provider/openai/multipart_converter_openai.py +++ b/src/fast_agent/llm/provider/openai/multipart_converter_openai.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Union from mcp.types import ( CallToolResult, @@ -30,8 +30,8 @@ _logger = get_logger("multipart_converter_openai") # Define type aliases for content blocks -ContentBlock = Dict[str, Any] -OpenAIMessage = Dict[str, Any] +ContentBlock = dict[str, Any] +OpenAIMessage = dict[str, Any] class OpenAIConverter: @@ -55,7 +55,7 @@ def _is_supported_image_type(mime_type: str) -> bool: @staticmethod def convert_to_openai( multipart_msg: PromptMessageExtended, concatenate_text_blocks: bool = False - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Convert a PromptMessageExtended message to OpenAI API format. @@ -70,7 +70,7 @@ def convert_to_openai( # assistant message with tool_calls per OpenAI format to establish the # required call IDs before tool responses appear. if multipart_msg.role == "assistant" and multipart_msg.tool_calls: - tool_calls_list: List[Dict[str, Any]] = [] + tool_calls_list: list[dict[str, Any]] = [] for tool_id, req in multipart_msg.tool_calls.items(): name = None arguments = {} @@ -122,7 +122,7 @@ def convert_to_openai( @staticmethod def _convert_content_to_message( content: list, role: str, concatenate_text_blocks: bool = False - ) -> Dict[str, Any] | None: + ) -> dict[str, Any] | None: """ Convert content blocks to a single OpenAI message. @@ -143,7 +143,7 @@ def _convert_content_to_message( return {"role": role, "content": get_text(content[0])} # For user messages, convert each content block - content_blocks: List[ContentBlock] = [] + content_blocks: list[ContentBlock] = [] _logger.debug(f"Converting {len(content)} content items for role '{role}'") @@ -195,7 +195,7 @@ def _convert_content_to_message( return result @staticmethod - def _concatenate_text_blocks(blocks: List[ContentBlock]) -> List[ContentBlock]: + def _concatenate_text_blocks(blocks: list[ContentBlock]) -> list[ContentBlock]: """ Combine adjacent text blocks into single blocks. @@ -208,7 +208,7 @@ def _concatenate_text_blocks(blocks: List[ContentBlock]) -> List[ContentBlock]: if not blocks: return [] - combined_blocks: List[ContentBlock] = [] + combined_blocks: list[ContentBlock] = [] current_text = "" for block in blocks: @@ -298,7 +298,7 @@ def _determine_mime_type(resource_content) -> str: @staticmethod def _convert_embedded_resource( resource: EmbeddedResource, - ) -> Optional[ContentBlock]: + ) -> ContentBlock | None: """ Convert EmbeddedResource to appropriate OpenAI content block. @@ -393,7 +393,7 @@ def _convert_embedded_resource( @staticmethod def _extract_text_from_content_blocks( - content: Union[str, List[ContentBlock]], + content: Union[str, list[ContentBlock]], ) -> str: """ Extract and combine text from content blocks. @@ -423,7 +423,7 @@ def convert_tool_result_to_openai( tool_result: CallToolResult, tool_call_id: str, concatenate_text_blocks: bool = False, - ) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + ) -> Union[dict[str, Any], tuple[dict[str, Any], list[dict[str, Any]]]]: """ Convert a CallToolResult to an OpenAI tool message. 
@@ -507,9 +507,9 @@ def convert_tool_result_to_openai( @staticmethod def convert_function_results_to_openai( - results: Dict[str, CallToolResult], + results: dict[str, CallToolResult], concatenate_text_blocks: bool = False, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Convert function call results to OpenAI messages. diff --git a/src/fast_agent/llm/provider/openai/openai_multipart.py b/src/fast_agent/llm/provider/openai/openai_multipart.py index 5856f9a1b..03293e439 100644 --- a/src/fast_agent/llm/provider/openai/openai_multipart.py +++ b/src/fast_agent/llm/provider/openai/openai_multipart.py @@ -4,7 +4,7 @@ Each function handles all content types consistently and is designed for simple testing. """ -from typing import Any, Dict, List, Union +from typing import Any, Union from mcp.types import ( BlobResourceContents, @@ -25,9 +25,9 @@ def openai_to_extended( message: Union[ ChatCompletionMessage, ChatCompletionMessageParam, - List[Union[ChatCompletionMessage, ChatCompletionMessageParam]], + list[Union[ChatCompletionMessage, ChatCompletionMessageParam]], ], -) -> Union[PromptMessageExtended, List[PromptMessageExtended]]: +) -> Union[PromptMessageExtended, list[PromptMessageExtended]]: """ Convert OpenAI messages to PromptMessageExtended format. @@ -43,7 +43,7 @@ def openai_to_extended( def _openai_message_to_extended( - message: Union[ChatCompletionMessage, Dict[str, Any]], + message: Union[ChatCompletionMessage, dict[str, Any]], ) -> PromptMessageExtended: """Convert a single OpenAI message to PromptMessageExtended.""" # Get role and content from message diff --git a/src/fast_agent/llm/provider/openai/openai_utils.py b/src/fast_agent/llm/provider/openai/openai_utils.py index ae46b660a..257d4eda0 100644 --- a/src/fast_agent/llm/provider/openai/openai_utils.py +++ b/src/fast_agent/llm/provider/openai/openai_utils.py @@ -5,7 +5,7 @@ delegating to the proper implementations in the providers/ directory. """ -from typing import Any, Dict, Union +from typing import Any, Union from openai.types.chat import ( ChatCompletionMessage, @@ -20,7 +20,7 @@ def openai_message_to_prompt_message_multipart( - message: Union[ChatCompletionMessage, Dict[str, Any]], + message: Union[ChatCompletionMessage, dict[str, Any]], ) -> PromptMessageExtended: """ Convert an OpenAI ChatCompletionMessage to a PromptMessageExtended. 
diff --git a/src/fast_agent/llm/provider/openai/responses.py b/src/fast_agent/llm/provider/openai/responses.py index ddc83b794..322abe0df 100644 --- a/src/fast_agent/llm/provider/openai/responses.py +++ b/src/fast_agent/llm/provider/openai/responses.py @@ -1,5 +1,4 @@ # from openai.types.beta.chat import -from typing import List from mcp import Tool from mcp.types import ContentBlock, TextContent @@ -49,9 +48,9 @@ async def _responses_client(self) -> AsyncOpenAI: async def _apply_prompt_provider_specific( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], request_params: RequestParams | None = None, - tools: List[Tool] | None = None, + tools: list[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: responses_client = await self._responses_client() @@ -83,7 +82,7 @@ async def _apply_prompt_provider_specific( ) final_response = await stream.get_final_response() - reasoning_content: List[ContentBlock] = [] + reasoning_content: list[ContentBlock] = [] for output_item in final_response.output: if isinstance(output_item, ResponseReasoningItem): summary_text = "\n".join(part.text for part in output_item.summary if part.text) diff --git a/src/fast_agent/llm/provider_key_manager.py b/src/fast_agent/llm/provider_key_manager.py index 8e93d0abd..9581853b8 100644 --- a/src/fast_agent/llm/provider_key_manager.py +++ b/src/fast_agent/llm/provider_key_manager.py @@ -4,19 +4,19 @@ """ import os -from typing import Any, Dict +from typing import Any from pydantic import BaseModel from fast_agent.core.exceptions import ProviderKeyError -PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = { +PROVIDER_ENVIRONMENT_MAP: dict[str, str] = { # default behaviour in _get_env_key_name is to capitalize the # provider name and suffix "_API_KEY" - so no specific mapping needed unless overriding "hf": "HF_TOKEN", "responses": "OPENAI_API_KEY", # Temporary workaround } -PROVIDER_CONFIG_KEY_ALIASES: Dict[str, tuple[str, ...]] = { +PROVIDER_CONFIG_KEY_ALIASES: dict[str, tuple[str, ...]] = { # HuggingFace historically used "huggingface" (full name) in config files, # while the provider id is "hf". Support both spellings. "hf": ("hf", "huggingface"), diff --git a/src/fast_agent/llm/request_params.py b/src/fast_agent/llm/request_params.py index 314de5b60..59c1027ef 100644 --- a/src/fast_agent/llm/request_params.py +++ b/src/fast_agent/llm/request_params.py @@ -2,7 +2,7 @@ Request parameters definitions for LLM interactions. """ -from typing import Any, Dict, List +from typing import Any from mcp import SamplingMessage from mcp.types import CreateMessageRequestParams @@ -16,7 +16,7 @@ class RequestParams(CreateMessageRequestParams): Parameters to configure the FastAgentLLM 'generate' requests. """ - messages: List[SamplingMessage] = Field(exclude=True, default=[]) + messages: list[SamplingMessage] = Field(exclude=True, default=[]) """ Ignored. 'messages' are removed from CreateMessageRequestParams to avoid confusion with the 'message' parameter on 'generate' method. @@ -50,12 +50,12 @@ class RequestParams(CreateMessageRequestParams): Override response format for structured calls. Prefer sending pydantic model - only use in exceptional circumstances """ - template_vars: Dict[str, Any] = Field(default_factory=dict) + template_vars: dict[str, Any] = Field(default_factory=dict) """ Optional dictionary of template variables for dynamic templates. 
Currently only works for TensorZero inference backend """ - mcp_metadata: Dict[str, Any] | None = None + mcp_metadata: dict[str, Any] | None = None """ Metadata to pass through to MCP tool calls via the _meta field. """ diff --git a/src/fast_agent/llm/sampling_converter.py b/src/fast_agent/llm/sampling_converter.py index b1f21eb29..3c0671b16 100644 --- a/src/fast_agent/llm/sampling_converter.py +++ b/src/fast_agent/llm/sampling_converter.py @@ -3,7 +3,6 @@ This replaces the more complex provider-specific converters with direct conversions. """ -from typing import List, Optional from mcp.types import ( CreateMessageRequestParams, @@ -62,7 +61,7 @@ def extract_request_params(params: CreateMessageRequestParams) -> RequestParams: ) @staticmethod - def error_result(error_message: str, model: Optional[str] = None) -> CreateMessageResult: + def error_result(error_message: str, model: str | None = None) -> CreateMessageResult: """ Create an error result. @@ -82,8 +81,8 @@ def error_result(error_message: str, model: Optional[str] = None) -> CreateMessa @staticmethod def convert_messages( - messages: List[SamplingMessage], - ) -> List[PromptMessageExtended]: + messages: list[SamplingMessage], + ) -> list[PromptMessageExtended]: """ Convert multiple SamplingMessages to PromptMessageExtended objects. diff --git a/src/fast_agent/llm/usage_tracking.py b/src/fast_agent/llm/usage_tracking.py index 3f02cb22d..30f3a79a8 100644 --- a/src/fast_agent/llm/usage_tracking.py +++ b/src/fast_agent/llm/usage_tracking.py @@ -6,7 +6,7 @@ """ import time -from typing import List, Optional, Union +from typing import Union # Proper type imports for each provider try: @@ -48,7 +48,7 @@ class ModelContextWindows: """Context window sizes and cache configurations for various models""" @classmethod - def get_context_window(cls, model: str) -> Optional[int]: + def get_context_window(cls, model: str) -> int | None: return ModelDatabase.get_context_window(model) @@ -232,8 +232,8 @@ def from_fast_agent(cls, usage: FastAgentUsage, model: str) -> "TurnUsage": class UsageAccumulator(BaseModel): """Accumulates usage data across multiple turns with cache analytics""" - turns: List[TurnUsage] = Field(default_factory=list) - model: Optional[str] = None + turns: list[TurnUsage] = Field(default_factory=list) + model: str | None = None def add_turn(self, turn: TurnUsage) -> None: """Add a new turn to the accumulator""" @@ -315,7 +315,7 @@ def cumulative_reasoning_tokens(self) -> int: @computed_field @property - def cache_hit_rate(self) -> Optional[float]: + def cache_hit_rate(self) -> float | None: """Percentage of total input context served from cache""" cache_tokens = self.cumulative_cache_read_tokens + self.cumulative_cache_hit_tokens total_input_context = self.cumulative_input_tokens + cache_tokens @@ -333,7 +333,7 @@ def current_context_tokens(self) -> int: @computed_field @property - def context_window_size(self) -> Optional[int]: + def context_window_size(self) -> int | None: """Get context window size for current model""" if self.model: return ModelContextWindows.get_context_window(self.model) @@ -341,7 +341,7 @@ def context_window_size(self) -> Optional[int]: @computed_field @property - def context_usage_percentage(self) -> Optional[float]: + def context_usage_percentage(self) -> float | None: """Percentage of context window used""" window_size = self.context_window_size if window_size and window_size > 0: diff --git a/src/fast_agent/mcp/elicitation_factory.py b/src/fast_agent/mcp/elicitation_factory.py index e834123a6..24b4393e4 
100644 --- a/src/fast_agent/mcp/elicitation_factory.py +++ b/src/fast_agent/mcp/elicitation_factory.py @@ -2,7 +2,7 @@ Factory for resolving elicitation handlers with proper precedence. """ -from typing import Any, Optional +from typing import Any from mcp.client.session import ElicitationFnT @@ -18,7 +18,7 @@ def resolve_elicitation_handler( agent_config: AgentConfig, app_config: Any, server_config: Any = None -) -> Optional[ElicitationFnT]: +) -> ElicitationFnT | None: """Resolve elicitation handler with proper precedence. Precedence order: diff --git a/src/fast_agent/mcp/helpers/content_helpers.py b/src/fast_agent/mcp/helpers/content_helpers.py index 8ad46369a..e567129c2 100644 --- a/src/fast_agent/mcp/helpers/content_helpers.py +++ b/src/fast_agent/mcp/helpers/content_helpers.py @@ -3,7 +3,7 @@ """ -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence, Union if TYPE_CHECKING: from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -21,7 +21,7 @@ ) -def get_text(content: ContentBlock) -> Optional[str]: +def get_text(content: ContentBlock) -> str | None: """Extract text content from a content object if available.""" if isinstance(content, TextContent): return content.text @@ -48,7 +48,7 @@ def get_text(content: ContentBlock) -> Optional[str]: return None -def get_image_data(content: ContentBlock) -> Optional[str]: +def get_image_data(content: ContentBlock) -> str | None: """Extract image data from a content object if available.""" if isinstance(content, ImageContent): return content.data @@ -60,7 +60,7 @@ def get_image_data(content: ContentBlock) -> Optional[str]: return None -def get_resource_uri(content: ContentBlock) -> Optional[str]: +def get_resource_uri(content: ContentBlock) -> str | None: """Extract resource URI from an EmbeddedResource if available.""" if isinstance(content, EmbeddedResource): return str(content.resource.uri) @@ -87,7 +87,7 @@ def is_resource_link(content: ContentBlock) -> bool: return isinstance(content, ResourceLink) -def get_resource_text(result: ReadResourceResult, index: int = 0) -> Optional[str]: +def get_resource_text(result: ReadResourceResult, index: int = 0) -> str | None: """Extract text content from a ReadResourceResult at the specified index.""" if index >= len(result.contents): raise IndexError( @@ -99,7 +99,7 @@ def get_resource_text(result: ReadResourceResult, index: int = 0) -> Optional[st return None -def split_thinking_content(message: str) -> tuple[Optional[str], str]: +def split_thinking_content(message: str) -> tuple[str | None, str]: """Split a message into thinking and content parts.""" import re @@ -127,8 +127,8 @@ def text_content(text: str) -> TextContent: def ensure_multipart_messages( - messages: List[Union["PromptMessageExtended", PromptMessage]], -) -> List["PromptMessageExtended"]: + messages: list[Union["PromptMessageExtended", PromptMessage]], +) -> list["PromptMessageExtended"]: """Ensure all messages in a list are PromptMessageExtended objects.""" # Import here to avoid circular dependency from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -153,7 +153,7 @@ def normalize_to_extended_list( "PromptMessageExtended", Sequence[Union[str, PromptMessage, "PromptMessageExtended"]], ], -) -> List["PromptMessageExtended"]: +) -> list["PromptMessageExtended"]: """Normalize various input types to a list of PromptMessageExtended objects.""" # Import here to avoid circular dependency from fast_agent.mcp.prompt_message_extended import 
PromptMessageExtended @@ -176,7 +176,7 @@ def normalize_to_extended_list( return [messages] # List of mixed types → convert each element - result: List[PromptMessageExtended] = [] + result: list[PromptMessageExtended] = [] for item in messages: if isinstance(item, str): result.append( diff --git a/src/fast_agent/mcp/helpers/server_config_helpers.py b/src/fast_agent/mcp/helpers/server_config_helpers.py index 79dc986b5..4d09eec59 100644 --- a/src/fast_agent/mcp/helpers/server_config_helpers.py +++ b/src/fast_agent/mcp/helpers/server_config_helpers.py @@ -1,12 +1,12 @@ """Helper functions for type-safe server config access.""" -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Union if TYPE_CHECKING: from fast_agent.config import MCPServerSettings -def get_server_config(ctx: Any) -> Optional["MCPServerSettings"]: +def get_server_config(ctx: Any) -> Union["MCPServerSettings", None]: """Extract server config from context if available. Type guard helper that safely accesses server_config with proper type checking. diff --git a/src/fast_agent/mcp/hf_auth.py b/src/fast_agent/mcp/hf_auth.py index 1af560087..6ea9dd320 100644 --- a/src/fast_agent/mcp/hf_auth.py +++ b/src/fast_agent/mcp/hf_auth.py @@ -1,7 +1,6 @@ """HuggingFace authentication utilities for MCP connections.""" import os -from typing import Dict, Optional from urllib.parse import urlparse @@ -48,7 +47,7 @@ def is_huggingface_url(url: str) -> bool: return False -def get_hf_token_from_env() -> Optional[str]: +def get_hf_token_from_env() -> str | None: """ Get the HuggingFace token from the HF_TOKEN environment variable. @@ -58,7 +57,7 @@ def get_hf_token_from_env() -> Optional[str]: return os.environ.get("HF_TOKEN") -def should_add_hf_auth(url: str, existing_headers: Optional[Dict[str, str]]) -> bool: +def should_add_hf_auth(url: str, existing_headers: dict[str, str] | None) -> bool: """ Determine if HuggingFace authentication should be added to the headers. @@ -98,7 +97,7 @@ def should_add_hf_auth(url: str, existing_headers: Optional[Dict[str, str]]) -> return get_hf_token_from_env() is not None -def add_hf_auth_header(url: str, headers: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]: +def add_hf_auth_header(url: str, headers: dict[str, str] | None) -> dict[str, str] | None: """ Add HuggingFace authentication header if appropriate. diff --git a/src/fast_agent/mcp/interfaces.py b/src/fast_agent/mcp/interfaces.py index 424a78f18..c7b190891 100644 --- a/src/fast_agent/mcp/interfaces.py +++ b/src/fast_agent/mcp/interfaces.py @@ -7,7 +7,6 @@ from typing import ( AsyncContextManager, Callable, - Optional, Protocol, runtime_checkable, ) @@ -44,16 +43,16 @@ class MCPConnectionManagerProtocol(Protocol): async def get_server( self, server_name: str, - client_session_factory: Optional[ + client_session_factory: Callable[ [ MemoryObjectReceiveStream, MemoryObjectSendStream, - Optional[timedelta], + timedelta | None, ], ClientSession, ] - ] = None, + | None = None, ) -> "ServerConnection": ... async def disconnect_server(self, server_name: str) -> None: ... @@ -71,16 +70,16 @@ def connection_manager(self) -> MCPConnectionManagerProtocol: ... def initialize_server( self, server_name: str, - client_session_factory: Optional[ + client_session_factory: Callable[ [ MemoryObjectReceiveStream, MemoryObjectSendStream, - Optional[timedelta], + timedelta | None, ], ClientSession, ] - ] = None, + | None = None, ) -> AsyncContextManager[ClientSession]: """Initialize a server and yield a client session.""" ... 
diff --git a/src/fast_agent/mcp/mcp_aggregator.py b/src/fast_agent/mcp/mcp_aggregator.py index 5e798e708..966c2e335 100644 --- a/src/fast_agent/mcp/mcp_aggregator.py +++ b/src/fast_agent/mcp/mcp_aggregator.py @@ -6,11 +6,9 @@ TYPE_CHECKING, Any, Callable, - Dict, - List, Mapping, - Optional, TypeVar, + Union, cast, ) @@ -92,7 +90,7 @@ class ServerStatus(BaseModel): last_call_at: datetime | None = None last_error_at: datetime | None = None staleness_seconds: float | None = None - call_counts: Dict[str, int] = Field(default_factory=dict) + call_counts: dict[str, int] = Field(default_factory=dict) error_message: str | None = None instructions_available: bool | None = None instructions_enabled: bool | None = None @@ -121,7 +119,7 @@ class MCPAggregator(ContextDependent): connection_persistence: bool = False """Whether to maintain a persistent connection to the server.""" - server_names: List[str] + server_names: list[str] """A list of server names to connect to.""" model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) @@ -160,11 +158,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def __init__( self, - server_names: List[str], + server_names: list[str], connection_persistence: bool = True, - context: Optional["Context"] = None, + context: Union["Context", None] = None, name: str | None = None, - config: Optional[Any] = None, # Accept the agent config for elicitation_handler access + config: Any | None = None, # Accept the agent config for elicitation_handler access tool_handler: ToolExecutionHandler | None = None, **kwargs, ) -> None: @@ -195,24 +193,24 @@ def __init__( logger = get_logger(logger_name) # Maps namespaced_tool_name -> namespaced tool info - self._namespaced_tool_map: Dict[str, NamespacedTool] = {} + self._namespaced_tool_map: dict[str, NamespacedTool] = {} # Maps server_name -> list of tools - self._server_to_tool_map: Dict[str, List[NamespacedTool]] = {} + self._server_to_tool_map: dict[str, list[NamespacedTool]] = {} self._tool_map_lock = Lock() # Cache for prompt objects, maps server_name -> list of prompt objects - self._prompt_cache: Dict[str, List[Prompt]] = {} + self._prompt_cache: dict[str, list[Prompt]] = {} self._prompt_cache_lock = Lock() # Lock for refreshing tools from a server self._refresh_lock = Lock() # Track runtime stats per server - self._server_stats: Dict[str, ServerStats] = {} + self._server_stats: dict[str, ServerStats] = {} self._stats_lock = Lock() # Track discovered Skybridge configurations per server - self._skybridge_configs: Dict[str, SkybridgeServerConfig] = {} + self._skybridge_configs: dict[str, SkybridgeServerConfig] = {} def _create_progress_callback( self, server_name: str, tool_name: str, tool_call_id: str @@ -266,7 +264,7 @@ async def close(self) -> None: @classmethod async def create( cls, - server_names: List[str], + server_names: list[str], connection_persistence: bool = False, ) -> "MCPAggregator": """ @@ -377,7 +375,7 @@ async def load_servers(self) -> None: }, ) - async def fetch_tools(server_name: str) -> List[Tool]: + async def fetch_tools(server_name: str) -> list[Tool]: # Only fetch tools if the server supports them if not await self.server_supports_feature(server_name, "tools"): logger.debug(f"Server '{server_name}' does not support tools") @@ -396,7 +394,7 @@ async def fetch_tools(server_name: str) -> List[Tool]: logger.error(f"Error loading tools from server '{server_name}'", data=e) return [] - async def fetch_prompts(server_name: str) -> List[Prompt]: + async def fetch_prompts(server_name: str) -> 
list[Prompt]: # Only fetch prompts if the server supports them if not await self.server_supports_feature(server_name, "prompts"): logger.debug(f"Server '{server_name}' does not support prompts") @@ -416,8 +414,8 @@ async def fetch_prompts(server_name: str) -> List[Prompt]: return [] async def load_server_data(server_name: str): - tools: List[Tool] = [] - prompts: List[Prompt] = [] + tools: list[Tool] = [] + prompts: list[Prompt] = [] # Use _execute_on_server for consistent tracking regardless of connection mode tools = await fetch_tools(server_name) @@ -504,7 +502,7 @@ async def _evaluate_skybridge_for_server( config = SkybridgeServerConfig(server_name=server_name) tool_entries = self._server_to_tool_map.get(server_name, []) - tool_configs: List[SkybridgeToolConfig] = [] + tool_configs: list[SkybridgeToolConfig] = [] for namespaced_tool in tool_entries: tool_meta = getattr(namespaced_tool.tool, "meta", None) or {} @@ -583,7 +581,7 @@ async def _evaluate_skybridge_for_server( continue contents = getattr(read_result, "contents", []) or [] - seen_mime_types: List[str] = [] + seen_mime_types: list[str] = [] for content in contents: mime_type = getattr(content, "mimeType", None) @@ -721,7 +719,7 @@ async def server_supports_feature(self, server_name: str, feature: str) -> bool: except Exception: # noqa: BLE001 return True - async def list_servers(self) -> List[str]: + async def list_servers(self) -> list[str]: """Return the list of server names aggregated by this agent.""" if not self.initialized: await self.load_servers() @@ -735,7 +733,7 @@ async def list_tools(self) -> ListToolsResult: if not self.initialized: await self.load_servers() - tools: List[Tool] = [] + tools: list[Tool] = [] for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items(): tool_copy = namespaced_tool.tool.model_copy( @@ -815,7 +813,7 @@ async def _notify_stdio_transport_activity( "Failed to notify stdio transport activity for %s", server_name, exc_info=True ) - async def get_server_instructions(self) -> Dict[str, tuple[str, List[str]]]: + async def get_server_instructions(self) -> dict[str, tuple[str, list[str]]]: """ Get instructions from all connected servers along with their tool names. 
@@ -846,13 +844,13 @@ async def get_server_instructions(self) -> Dict[str, tuple[str, List[str]]]: return instructions - async def collect_server_status(self) -> Dict[str, ServerStatus]: + async def collect_server_status(self) -> dict[str, ServerStatus]: """Return aggregated status information for each configured server.""" if not self.initialized: await self.load_servers() now = datetime.now(timezone.utc) - status_map: Dict[str, ServerStatus] = {} + status_map: dict[str, ServerStatus] = {} for server_name in self.server_names: stats = self._server_stats.get(server_name) @@ -1017,7 +1015,7 @@ async def collect_server_status(self) -> Dict[str, ServerStatus]: return status_map - async def get_skybridge_configs(self) -> Dict[str, SkybridgeServerConfig]: + async def get_skybridge_configs(self) -> dict[str, SkybridgeServerConfig]: """Expose discovered Skybridge configurations keyed by server.""" if not self.initialized: await self.load_servers() @@ -1035,7 +1033,7 @@ async def _execute_on_server( operation_type: str, operation_name: str, method_name: str, - method_args: Dict[str, Any] = None, + method_args: dict[str, Any] = None, error_factory: Callable[[str], R] | None = None, progress_callback: ProgressFnT | None = None, ) -> R: @@ -1550,7 +1548,7 @@ async def get_prompt( async def list_prompts( self, server_name: str | None = None, agent_name: str | None = None - ) -> Mapping[str, List[Prompt]]: + ) -> Mapping[str, list[Prompt]]: """ List available prompts from one or all servers. @@ -1562,7 +1560,7 @@ async def list_prompts( if not self.initialized: await self.load_servers() - results: Dict[str, List[Prompt]] = {} + results: dict[str, list[Prompt]] = {} # If specific server requested if server_name: @@ -1821,7 +1819,7 @@ async def _get_resource_from_server( async def _list_resources_from_server( self, server_name: str, *, check_support: bool = True - ) -> List[Any]: + ) -> list[Any]: """ Internal helper method to list resources from a specific server. @@ -1845,7 +1843,7 @@ async def _list_resources_from_server( return getattr(result, "resources", []) or [] - async def list_resources(self, server_name: str | None = None) -> Dict[str, List[str]]: + async def list_resources(self, server_name: str | None = None) -> dict[str, list[str]]: """ List available resources from one or all servers. @@ -1859,7 +1857,7 @@ async def list_resources(self, server_name: str | None = None) -> Dict[str, List if not self.initialized: await self.load_servers() - results: Dict[str, List[str]] = {} + results: dict[str, list[str]] = {} # Get the list of servers to check servers_to_check = [server_name] if server_name else self.server_names @@ -1880,7 +1878,7 @@ async def list_resources(self, server_name: str | None = None) -> Dict[str, List try: resources = await self._list_resources_from_server(s_name, check_support=False) - formatted_resources: List[str] = [] + formatted_resources: list[str] = [] for resource in resources: uri = getattr(resource, "uri", None) if uri is not None: @@ -1891,7 +1889,7 @@ async def list_resources(self, server_name: str | None = None) -> Dict[str, List return results - async def list_mcp_tools(self, server_name: str | None = None) -> Dict[str, List[Tool]]: + async def list_mcp_tools(self, server_name: str | None = None) -> dict[str, list[Tool]]: """ List available tools from one or all servers, grouped by server name. 
@@ -1905,7 +1903,7 @@ async def list_mcp_tools(self, server_name: str | None = None) -> Dict[str, List if not self.initialized: await self.load_servers() - results: Dict[str, List[Tool]] = {} + results: dict[str, list[Tool]] = {} # Get the list of servers to check servers_to_check = [server_name] if server_name else self.server_names diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index 6416e8808..cd158f4c7 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -5,13 +5,7 @@ import asyncio import traceback from datetime import timedelta -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - Callable, - Dict, - Optional, -) +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Union from anyio import Event, Lock, create_task_group from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -69,7 +63,7 @@ def _add_none_to_context(context_manager): def _prepare_headers_and_auth( server_config: MCPServerSettings, -) -> tuple[dict[str, str], Optional["OAuthClientProvider"], set[str]]: +) -> tuple[dict[str, str], Union["OAuthClientProvider", None], set[str]]: """ Prepare request headers and determine if OAuth authentication should be used. @@ -359,11 +353,11 @@ class MCPConnectionManager(ContextDependent): """ def __init__( - self, server_registry: "ServerRegistry", context: Optional["Context"] = None + self, server_registry: "ServerRegistry", context: Union["Context", None] = None ) -> None: super().__init__(context=context) self.server_registry = server_registry - self.running_servers: Dict[str, ServerConnection] = {} + self.running_servers: dict[str, ServerConnection] = {} self._lock = Lock() # Manage our own task group - independent of task context self._task_group = None diff --git a/src/fast_agent/mcp/mcp_content.py b/src/fast_agent/mcp/mcp_content.py index b1ba2e8f0..761ff7735 100644 --- a/src/fast_agent/mcp/mcp_content.py +++ b/src/fast_agent/mcp/mcp_content.py @@ -7,7 +7,7 @@ import base64 from pathlib import Path -from typing import Any, List, Literal, Optional, Union +from typing import Any, Literal, Union from mcp.types import ( Annotations, @@ -54,7 +54,7 @@ def MCPText( def MCPImage( path: str | Path | None = None, data: bytes | None = None, - mime_type: Optional[str] = None, + mime_type: str | None = None, role: Literal["user", "assistant"] = "user", annotations: Annotations | None = None, ) -> dict: @@ -99,7 +99,7 @@ def MCPImage( def MCPFile( path: Union[str, Path], - mime_type: Optional[str] = None, + mime_type: str | None = None, role: Literal["user", "assistant"] = "user", annotations: Annotations | None = None, ) -> dict: @@ -152,7 +152,7 @@ def MCPFile( def MCPPrompt( *content_items: Union[dict, str, Path, bytes, ContentBlock, ReadResourceResult], role: Literal["user", "assistant"] = "user", -) -> List[dict]: +) -> list[dict]: """ Create one or more prompt messages with various content types. 
@@ -235,14 +235,14 @@ def MCPPrompt(
 
 
 def User(
     *content_items: Union[dict, str, Path, bytes, ContentBlock, ReadResourceResult],
-) -> List[dict]:
+) -> list[dict]:
     """Create user message(s) with various content types."""
     return MCPPrompt(*content_items, role="user")
 
 
 def Assistant(
     *content_items: Union[dict, str, Path, bytes, ContentBlock, ReadResourceResult],
-) -> List[dict]:
+) -> list[dict]:
     """Create assistant message(s) with various content types."""
     return MCPPrompt(*content_items, role="assistant")
diff --git a/src/fast_agent/mcp/prompt.py b/src/fast_agent/mcp/prompt.py
index ab0e5b9da..0e8031701 100644
--- a/src/fast_agent/mcp/prompt.py
+++ b/src/fast_agent/mcp/prompt.py
@@ -7,7 +7,7 @@
 """
 
 from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import Literal, Union
 
 from mcp import CallToolRequest
 from mcp.types import ContentBlock, PromptMessage
@@ -65,7 +65,7 @@ def assistant(
             str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended
         ],
         stop_reason: LlmStopReason | None = None,
-        tool_calls: Dict[str, CallToolRequest] | None = None,
+        tool_calls: dict[str, CallToolRequest] | None = None,
     ) -> PromptMessageExtended:
         """
         Create an assistant PromptMessageExtended with various content items.
@@ -126,7 +126,7 @@ def message(
         )
 
     @classmethod
-    def conversation(cls, *messages) -> List[PromptMessage]:
+    def conversation(cls, *messages) -> list[PromptMessage]:
         """
         Create a list of PromptMessages from various inputs.
         """
@@ -149,7 +149,7 @@ def conversation(cls, *messages) -> List[PromptMessage]:
         return result
 
     @classmethod
-    def from_multipart(cls, multipart: List[PromptMessageExtended]) -> List[PromptMessage]:
+    def from_multipart(cls, multipart: list[PromptMessageExtended]) -> list[PromptMessage]:
         """
         Convert a list of PromptMessageExtended objects to PromptMessages.
         """
diff --git a/src/fast_agent/mcp/prompt_message_extended.py b/src/fast_agent/mcp/prompt_message_extended.py
index 62a111f76..5ad16d9c2 100644
--- a/src/fast_agent/mcp/prompt_message_extended.py
+++ b/src/fast_agent/mcp/prompt_message_extended.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Mapping, Optional, Sequence
+from typing import Mapping, Sequence
 
 from mcp.types import (
     CallToolRequest,
@@ -24,15 +24,15 @@ class PromptMessageExtended(BaseModel):
     """
 
     role: Role
-    content: List[ContentBlock] = []
-    tool_calls: Dict[str, CallToolRequest] | None = None
-    tool_results: Dict[str, CallToolResult] | None = None
+    content: list[ContentBlock] = []
+    tool_calls: dict[str, CallToolRequest] | None = None
+    tool_results: dict[str, CallToolResult] | None = None
     channels: Mapping[str, Sequence[ContentBlock]] | None = None
     stop_reason: LlmStopReason | None = None
     is_template: bool = False
 
     @classmethod
-    def to_extended(cls, messages: List[PromptMessage]) -> List["PromptMessageExtended"]:
+    def to_extended(cls, messages: list[PromptMessage]) -> list["PromptMessageExtended"]:
         """Convert a sequence of PromptMessages into PromptMessageExtended objects."""
         if not messages:
             return []
@@ -59,7 +59,7 @@ def to_extended(cls, messages: List[PromptMessage]) -> List["PromptMessageExtend
 
         return result
 
-    def from_multipart(self) -> List[PromptMessage]:
+    def from_multipart(self) -> list[PromptMessage]:
         """Convert this PromptMessageExtended to a sequence of standard PromptMessages."""
         return [
             PromptMessage(role=self.role, content=content_part) for content_part in self.content
@@ -124,7 +124,7 @@ def add_text(self, to_add: str) -> TextContent:
         return text
 
     @classmethod
-    def parse_get_prompt_result(cls, result: GetPromptResult) -> List["PromptMessageExtended"]:
+    def parse_get_prompt_result(cls, result: GetPromptResult) -> list["PromptMessageExtended"]:
         """
         Parse a GetPromptResult into PromptMessageExtended objects.
 
@@ -138,8 +138,8 @@ def parse_get_prompt_result(cls, result: GetPromptResult) -> List["PromptMessage
 
     @classmethod
     def from_get_prompt_result(
-        cls, result: Optional[GetPromptResult]
-    ) -> List["PromptMessageExtended"]:
+        cls, result: GetPromptResult | None
+    ) -> list["PromptMessageExtended"]:
         """
         Convert a GetPromptResult to PromptMessageExtended objects with error handling.
         This method safely handles None values and empty results.
diff --git a/src/fast_agent/mcp/prompt_render.py b/src/fast_agent/mcp/prompt_render.py
index b133fac70..ae9edab7f 100644
--- a/src/fast_agent/mcp/prompt_render.py
+++ b/src/fast_agent/mcp/prompt_render.py
@@ -2,7 +2,6 @@
 Utilities for rendering PromptMessageExtended objects for display.
""" -from typing import List from mcp.types import BlobResourceContents, TextResourceContents @@ -29,7 +28,7 @@ def render_multipart_message(message: PromptMessageExtended) -> str: Returns: A string representation of the message's content """ - rendered_parts: List[str] = [] + rendered_parts: list[str] = [] for content in message.content: if is_text_content(content): diff --git a/src/fast_agent/mcp/prompt_serialization.py b/src/fast_agent/mcp/prompt_serialization.py index 344f07a8d..2f6b1527d 100644 --- a/src/fast_agent/mcp/prompt_serialization.py +++ b/src/fast_agent/mcp/prompt_serialization.py @@ -17,7 +17,6 @@ """ import json -from typing import List from mcp.types import ( EmbeddedResource, @@ -59,7 +58,7 @@ def serialize_to_dict(obj, exclude_none: bool = True): def to_get_prompt_result( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], ) -> GetPromptResult: """ Convert PromptMessageExtended objects to a GetPromptResult container. @@ -80,7 +79,7 @@ def to_get_prompt_result( -def to_get_prompt_result_json(messages: List[PromptMessageExtended]) -> str: +def to_get_prompt_result_json(messages: list[PromptMessageExtended]) -> str: """ Convert PromptMessageExtended objects to MCP-compatible GetPromptResult JSON. @@ -98,7 +97,7 @@ def to_get_prompt_result_json(messages: List[PromptMessageExtended]) -> str: return json.dumps(result_dict, indent=2) -def to_json(messages: List[PromptMessageExtended]) -> str: +def to_json(messages: list[PromptMessageExtended]) -> str: """ Convert PromptMessageExtended objects directly to JSON, preserving all extended fields. @@ -121,7 +120,7 @@ def to_json(messages: List[PromptMessageExtended]) -> str: return json.dumps(result_dict, indent=2) -def from_json(json_str: str) -> List[PromptMessageExtended]: +def from_json(json_str: str) -> list[PromptMessageExtended]: """ Parse a JSON string into PromptMessageExtended objects. @@ -141,8 +140,8 @@ def from_json(json_str: str) -> List[PromptMessageExtended]: # Extract messages array messages_data = result_dict.get("messages", []) - extended_messages: List[PromptMessageExtended] = [] - basic_buffer: List[PromptMessage] = [] + extended_messages: list[PromptMessageExtended] = [] + basic_buffer: list[PromptMessage] = [] def flush_basic_buffer() -> None: nonlocal basic_buffer @@ -175,7 +174,7 @@ def flush_basic_buffer() -> None: return extended_messages -def save_json(messages: List[PromptMessageExtended], file_path: str) -> None: +def save_json(messages: list[PromptMessageExtended], file_path: str) -> None: """ Save PromptMessageExtended objects to a JSON file using enhanced format. @@ -192,7 +191,7 @@ def save_json(messages: List[PromptMessageExtended], file_path: str) -> None: f.write(json_str) -def load_json(file_path: str) -> List[PromptMessageExtended]: +def load_json(file_path: str) -> list[PromptMessageExtended]: """ Load PromptMessageExtended objects from a JSON file. @@ -210,7 +209,7 @@ def load_json(file_path: str) -> List[PromptMessageExtended]: return from_json(json_str) -def save_messages(messages: List[PromptMessageExtended], file_path: str) -> None: +def save_messages(messages: list[PromptMessageExtended], file_path: str) -> None: """ Save PromptMessageExtended objects to a file, with format determined by file extension. 
@@ -229,7 +228,7 @@ def save_messages(messages: List[PromptMessageExtended], file_path: str) -> None save_delimited(messages, file_path) -def load_messages(file_path: str) -> List[PromptMessageExtended]: +def load_messages(file_path: str) -> list[PromptMessageExtended]: """ Load PromptMessageExtended objects from a file, with format determined by file extension. @@ -255,12 +254,12 @@ def load_messages(file_path: str) -> List[PromptMessageExtended]: def multipart_messages_to_delimited_format( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], user_delimiter: str = USER_DELIMITER, assistant_delimiter: str = ASSISTANT_DELIMITER, resource_delimiter: str = RESOURCE_DELIMITER, combine_text: bool = True, # Set to False to maintain backward compatibility -) -> List[str]: +) -> list[str]: """ Convert PromptMessageExtended objects to a hybrid delimited format: - Plain text for user/assistant text content with delimiters @@ -338,7 +337,7 @@ def delimited_format_to_extended_messages( user_delimiter: str = USER_DELIMITER, assistant_delimiter: str = ASSISTANT_DELIMITER, resource_delimiter: str = RESOURCE_DELIMITER, -) -> List[PromptMessageExtended]: +) -> list[PromptMessageExtended]: """ Parse hybrid delimited format into PromptMessageExtended objects: - Plain text for user/assistant text content with delimiters @@ -522,7 +521,7 @@ def delimited_format_to_extended_messages( def save_delimited( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], file_path: str, user_delimiter: str = USER_DELIMITER, assistant_delimiter: str = ASSISTANT_DELIMITER, @@ -557,7 +556,7 @@ def load_delimited( user_delimiter: str = USER_DELIMITER, assistant_delimiter: str = ASSISTANT_DELIMITER, resource_delimiter: str = RESOURCE_DELIMITER, -) -> List[PromptMessageExtended]: +) -> list[PromptMessageExtended]: """ Load PromptMessageExtended objects from a file in hybrid delimited format. diff --git a/src/fast_agent/mcp/prompts/prompt_helpers.py b/src/fast_agent/mcp/prompts/prompt_helpers.py index e3e2dcb29..eaef2970b 100644 --- a/src/fast_agent/mcp/prompts/prompt_helpers.py +++ b/src/fast_agent/mcp/prompts/prompt_helpers.py @@ -5,7 +5,7 @@ without repetitive type checking. """ -from typing import List, Optional, Union, cast +from typing import Union, cast from mcp.types import ( EmbeddedResource, @@ -34,7 +34,7 @@ class MessageContent: """ @staticmethod - def get_all_text(message: Union[PromptMessage, "PromptMessageExtended"]) -> List[str]: + def get_all_text(message: Union[PromptMessage, "PromptMessageExtended"]) -> list[str]: """ Extract all text content from a message. @@ -73,7 +73,7 @@ def join_text( return separator.join(MessageContent.get_all_text(message)) @staticmethod - def get_first_text(message: Union[PromptMessage, "PromptMessageExtended"]) -> Optional[str]: + def get_first_text(message: Union[PromptMessage, "PromptMessageExtended"]) -> str | None: """ Get the first available text content from a message. @@ -114,7 +114,7 @@ def has_text_at_first_position(message: Union[PromptMessage, "PromptMessageExten @staticmethod def get_text_at_first_position( message: Union[PromptMessage, "PromptMessageExtended"], - ) -> Optional[str]: + ) -> str | None: """ Get the text from the first position of a message if it's TextContent. 
@@ -135,7 +135,7 @@ def get_text_at_first_position( return cast("TextContent", message.content[0]).text @staticmethod - def get_all_images(message: Union[PromptMessage, "PromptMessageExtended"]) -> List[str]: + def get_all_images(message: Union[PromptMessage, "PromptMessageExtended"]) -> list[str]: """ Extract all image data from a message. @@ -158,7 +158,7 @@ def get_all_images(message: Union[PromptMessage, "PromptMessageExtended"]) -> Li return result @staticmethod - def get_first_image(message: Union[PromptMessage, "PromptMessageExtended"]) -> Optional[str]: + def get_first_image(message: Union[PromptMessage, "PromptMessageExtended"]) -> str | None: """ Get the first available image data from a message. @@ -181,7 +181,7 @@ def get_first_image(message: Union[PromptMessage, "PromptMessageExtended"]) -> O @staticmethod def get_all_resources( message: Union[PromptMessage, "PromptMessageExtended"], - ) -> List[EmbeddedResource]: + ) -> list[EmbeddedResource]: """ Extract all embedded resources from a message. diff --git a/src/fast_agent/mcp/prompts/prompt_load.py b/src/fast_agent/mcp/prompts/prompt_load.py index 305a36edf..1b8b73e29 100644 --- a/src/fast_agent/mcp/prompts/prompt_load.py +++ b/src/fast_agent/mcp/prompts/prompt_load.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Literal +from typing import Literal from mcp.server.fastmcp.prompts.base import ( AssistantMessage, @@ -31,8 +31,8 @@ def cast_message_role(role: str) -> MessageRole: def create_messages_with_resources( - content_sections: List[PromptContent], prompt_files: List[Path] -) -> List[PromptMessage]: + content_sections: list[PromptContent], prompt_files: list[Path] +) -> list[PromptMessage]: """ Create a list of messages from content sections, with resources properly handled. @@ -99,7 +99,7 @@ def create_resource_message( return message_class(content=embedded_resource) -def load_prompt(file: Path) -> List[PromptMessageExtended]: +def load_prompt(file: Path) -> list[PromptMessageExtended]: """ Load a prompt from a file and return as PromptMessageExtended objects. diff --git a/src/fast_agent/mcp/prompts/prompt_server.py b/src/fast_agent/mcp/prompts/prompt_server.py index a6fe9cb2d..c511622c4 100644 --- a/src/fast_agent/mcp/prompts/prompt_server.py +++ b/src/fast_agent/mcp/prompts/prompt_server.py @@ -11,7 +11,7 @@ import logging import sys from pathlib import Path -from typing import Any, Awaitable, Callable, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Union from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.prompts.base import ( @@ -48,7 +48,7 @@ mcp = FastMCP("Prompt Server") -def convert_to_fastmcp_messages(prompt_messages: List[Union[PromptMessage, PromptMessageExtended]]) -> List[Message]: +def convert_to_fastmcp_messages(prompt_messages: list[Union[PromptMessage, PromptMessageExtended]]) -> list[Message]: """ Convert PromptMessage or PromptMessageExtended objects to FastMCP Message objects. This adapter prevents double-wrapping of messages and handles both types. 
@@ -90,7 +90,7 @@ def convert_to_fastmcp_messages(prompt_messages: List[Union[PromptMessage, Promp class PromptConfig(PromptMetadata): """Configuration for the prompt server""" - prompt_files: List[Path] = [] + prompt_files: list[Path] = [] user_delimiter: str = DEFAULT_USER_DELIMITER assistant_delimiter: str = DEFAULT_ASSISTANT_DELIMITER resource_delimiter: str = DEFAULT_RESOURCE_DELIMITER @@ -101,12 +101,12 @@ class PromptConfig(PromptMetadata): # We'll maintain registries of all exposed resources and prompts -exposed_resources: Dict[str, Path] = {} -prompt_registry: Dict[str, PromptMetadata] = {} +exposed_resources: dict[str, Path] = {} +prompt_registry: dict[str, PromptMetadata] = {} # Define a single type for prompt handlers to avoid mypy issues -PromptHandler = Callable[..., Awaitable[List[Message]]] +PromptHandler = Callable[..., Awaitable[list[Message]]] # Type for resource handler @@ -132,8 +132,8 @@ async def get_resource() -> str | bytes: def get_delimiter_config( - config: Optional[PromptConfig] = None, file_path: Optional[Path] = None -) -> Dict[str, Any]: + config: PromptConfig | None = None, file_path: Path | None = None +) -> dict[str, Any]: """Get delimiter configuration, falling back to defaults if config is None""" # Set defaults config_values = { @@ -153,7 +153,7 @@ def get_delimiter_config( return config_values -def register_prompt(file_path: Path, config: Optional[PromptConfig] = None) -> None: +def register_prompt(file_path: Path, config: PromptConfig | None = None) -> None: """Register a prompt file""" try: # Check if it's a JSON file for ultra-minimal path diff --git a/src/fast_agent/mcp/prompts/prompt_template.py b/src/fast_agent/mcp/prompts/prompt_template.py index 55b79c66b..a82324e3a 100644 --- a/src/fast_agent/mcp/prompts/prompt_template.py +++ b/src/fast_agent/mcp/prompts/prompt_template.py @@ -7,7 +7,7 @@ import re from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Set +from typing import Any, Literal from mcp.types import ( EmbeddedResource, @@ -33,8 +33,8 @@ class PromptMetadata(BaseModel): name: str description: str - template_variables: Set[str] = set() - resource_paths: List[str] = [] + template_variables: set[str] = set() + resource_paths: list[str] = [] file_path: Path @@ -47,7 +47,7 @@ class PromptContent(BaseModel): text: str role: str = "user" - resources: List[str] = [] + resources: list[str] = [] @field_validator("role") @classmethod @@ -57,7 +57,7 @@ def validate_role(cls, role: str) -> str: raise ValueError(f"Invalid role: {role}. Must be one of: user, assistant") return role - def apply_substitutions(self, context: Dict[str, Any]) -> "PromptContent": + def apply_substitutions(self, context: dict[str, Any]) -> "PromptContent": """Apply variable substitutions to the text and resources""" # Define placeholder pattern once to avoid repetition @@ -88,8 +88,8 @@ class PromptTemplate: def __init__( self, template_text: str, - delimiter_map: Optional[Dict[str, str]] = None, - template_file_path: Optional[Path] = None, + delimiter_map: dict[str, str] | None = None, + template_file_path: Path | None = None, ) -> None: """ Initialize a prompt template. @@ -108,8 +108,8 @@ def __init__( @classmethod def from_multipart_messages( cls, - messages: List[PromptMessageExtended], - delimiter_map: Optional[Dict[str, str]] = None, + messages: list[PromptMessageExtended], + delimiter_map: dict[str, str] | None = None, ) -> "PromptTemplate": """ Create a PromptTemplate from a list of PromptMessageExtended objects. 
@@ -143,16 +143,16 @@ def from_multipart_messages( return cls(content, delimiter_map) @property - def template_variables(self) -> Set[str]: + def template_variables(self) -> set[str]: """Get the template variables in this template""" return self._template_variables @property - def content_sections(self) -> List[PromptContent]: + def content_sections(self) -> list[PromptContent]: """Get the parsed content sections""" return self._parsed_content - def apply_substitutions(self, context: Dict[str, Any]) -> List[PromptContent]: + def apply_substitutions(self, context: dict[str, Any]) -> list[PromptContent]: """ Apply variable substitutions to the template. @@ -166,8 +166,8 @@ def apply_substitutions(self, context: Dict[str, Any]) -> List[PromptContent]: return [section.apply_substitutions(context) for section in self._parsed_content] def apply_substitutions_to_extended( - self, context: Dict[str, Any] - ) -> List[PromptMessageExtended]: + self, context: dict[str, Any] + ) -> list[PromptMessageExtended]: """ Apply variable substitutions to the template and return PromptMessageExtended objects. @@ -205,13 +205,13 @@ def apply_substitutions_to_extended( return multiparts - def _extract_template_variables(self, text: str) -> Set[str]: + def _extract_template_variables(self, text: str) -> set[str]: """Extract template variables from text using regex""" variable_pattern = r"{{([^}]+)}}" matches = re.findall(variable_pattern, text) return set(matches) - def to_extended_messages(self) -> List[PromptMessageExtended]: + def to_extended_messages(self) -> list[PromptMessageExtended]: """ Convert this template to a list of PromptMessageExtended objects. @@ -243,7 +243,7 @@ def to_extended_messages(self) -> List[PromptMessageExtended]: return multiparts - def _parse_template(self) -> List[PromptContent]: + def _parse_template(self) -> list[PromptContent]: """ Parse the template into sections based on delimiters. If no delimiters are found, treat the entire template as a single user message. @@ -324,7 +324,7 @@ class PromptTemplateLoader: Loads and processes prompt templates from files. """ - def __init__(self, delimiter_map: Optional[Dict[str, str]] = None) -> None: + def __init__(self, delimiter_map: dict[str, str] | None = None) -> None: """ Initialize the loader with optional custom delimiters. @@ -348,7 +348,7 @@ def load_from_file(self, file_path: Path) -> PromptTemplate: return PromptTemplate(content, self.delimiter_map, template_file_path=file_path) - def load_from_multipart(self, messages: List[PromptMessageExtended]) -> PromptTemplate: + def load_from_multipart(self, messages: list[PromptMessageExtended]) -> PromptTemplate: """ Create a PromptTemplate from a list of PromptMessageExtended objects. 
diff --git a/src/fast_agent/mcp/resource_utils.py b/src/fast_agent/mcp/resource_utils.py index 024d39caa..4f48fd525 100644 --- a/src/fast_agent/mcp/resource_utils.py +++ b/src/fast_agent/mcp/resource_utils.py @@ -1,6 +1,5 @@ import base64 from pathlib import Path -from typing import List, Optional, Tuple from mcp.types import ( BlobResourceContents, @@ -15,10 +14,10 @@ HTTP_TIMEOUT = 10 # Default timeout for HTTP requests # Define a type alias for resource content results -ResourceContent = Tuple[str, str, bool] +ResourceContent = tuple[str, str, bool] -def find_resource_file(resource_path: str, prompt_files: List[Path]) -> Optional[Path]: +def find_resource_file(resource_path: str, prompt_files: list[Path]) -> Path | None: """Find a resource file relative to one of the prompt files""" for prompt_file in prompt_files: potential_path = prompt_file.parent / resource_path @@ -27,7 +26,7 @@ def find_resource_file(resource_path: str, prompt_files: List[Path]) -> Optional return None -def load_resource_content(resource_path: str, prompt_files: List[Path]) -> ResourceContent: +def load_resource_content(resource_path: str, prompt_files: list[Path]) -> ResourceContent: """ Load a resource's content and determine its mime type diff --git a/src/fast_agent/mcp/server/agent_server.py b/src/fast_agent/mcp/server/agent_server.py index 771de6c49..d6d63fe06 100644 --- a/src/fast_agent/mcp/server/agent_server.py +++ b/src/fast_agent/mcp/server/agent_server.py @@ -8,7 +8,7 @@ import signal import time from contextlib import AsyncExitStack, asynccontextmanager -from typing import Awaitable, Callable, Set +from typing import Any, Awaitable, Callable from mcp.server.fastmcp import Context as MCPContext from mcp.server.fastmcp import FastMCP @@ -56,7 +56,7 @@ def __init__( # Resource management self._exit_stack = AsyncExitStack() - self._active_connections: Set[any] = set() + self._active_connections: set[Any] = set() # Server state self._server_task = None diff --git a/src/fast_agent/mcp/skybridge.py b/src/fast_agent/mcp/skybridge.py index 280c35c9a..cee0f240c 100644 --- a/src/fast_agent/mcp/skybridge.py +++ b/src/fast_agent/mcp/skybridge.py @@ -1,4 +1,3 @@ -from typing import List from pydantic import AnyUrl, BaseModel, Field @@ -34,9 +33,9 @@ class SkybridgeServerConfig(BaseModel): server_name: str supports_resources: bool = False - ui_resources: List[SkybridgeResourceConfig] = Field(default_factory=list) - warnings: List[str] = Field(default_factory=list) - tools: List[SkybridgeToolConfig] = Field(default_factory=list) + ui_resources: list[SkybridgeResourceConfig] = Field(default_factory=list) + warnings: list[str] = Field(default_factory=list) + tools: list[SkybridgeToolConfig] = Field(default_factory=list) @property def enabled(self) -> bool: diff --git a/src/fast_agent/mcp/ui_mixin.py b/src/fast_agent/mcp/ui_mixin.py index 051e580f3..98f8f84c7 100644 --- a/src/fast_agent/mcp/ui_mixin.py +++ b/src/fast_agent/mcp/ui_mixin.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Sequence from mcp.types import CallToolResult, ContentBlock, EmbeddedResource @@ -39,7 +39,7 @@ def __init__(self, *args, ui_mode: str = "auto", **kwargs): """Initialize the mixin with UI mode configuration.""" super().__init__(*args, **kwargs) self._ui_mode: str = ui_mode - self._pending_ui_resources: List[ContentBlock] = [] + self._pending_ui_resources: list[ContentBlock] = [] def set_ui_mode(self, mode: str) -> None: """ @@ -89,12
+89,12 @@ async def run_tools(self, request: "PromptMessageExtended") -> "PromptMessageExt async def show_assistant_message( self, message: "PromptMessageExtended", - bottom_items: List[str] | None = None, - highlight_items: str | List[str] | None = None, + bottom_items: list[str] | None = None, + highlight_items: str | list[str] | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, - additional_message: Optional["Text"] = None, + additional_message: "Text" | None = None, ) -> None: """Override to display UI resources after showing assistant message.""" # Show the assistant message normally via parent @@ -150,8 +150,8 @@ async def _display_ui_resources(self, resources: Sequence[ContentBlock]) -> None def _extract_ui_from_tool_results( self, - tool_results: Dict[str, CallToolResult], - ) -> Tuple[Dict[str, CallToolResult], List[ContentBlock]]: + tool_results: dict[str, CallToolResult], + ) -> tuple[dict[str, CallToolResult], list[ContentBlock]]: """ Extract UI resources from tool results. @@ -160,8 +160,8 @@ def _extract_ui_from_tool_results( if not tool_results: return tool_results, [] - extracted_ui: List[ContentBlock] = [] - new_results: Dict[str, CallToolResult] = {} + extracted_ui: list[ContentBlock] = [] + new_results: dict[str, CallToolResult] = {} for key, result in tool_results.items(): try: @@ -178,15 +178,15 @@ def _extract_ui_from_tool_results( return new_results, extracted_ui def _split_ui_blocks( - self, blocks: List[ContentBlock] - ) -> Tuple[List[ContentBlock], List[ContentBlock]]: + self, blocks: list[ContentBlock] + ) -> tuple[list[ContentBlock], list[ContentBlock]]: """ Split content blocks into UI and non-UI blocks. Returns tuple of (ui_blocks, other_blocks). """ - ui_blocks: List[ContentBlock] = [] - other_blocks: List[ContentBlock] = [] + ui_blocks: list[ContentBlock] = [] + other_blocks: list[ContentBlock] = [] for block in blocks or []: if self._is_ui_embedded_resource(block): diff --git a/src/fast_agent/mcp_server_registry.py b/src/fast_agent/mcp_server_registry.py index f4b5d320d..9d074c2af 100644 --- a/src/fast_agent/mcp_server_registry.py +++ b/src/fast_agent/mcp_server_registry.py @@ -7,7 +7,6 @@ server initialization. """ -from typing import Dict from fast_agent.config import ( MCPServerSettings, @@ -25,10 +24,10 @@ class ServerRegistry: Attributes: config_path (str): Path to the YAML configuration file. - registry (Dict[str, MCPServerSettings]): Loaded server configurations. + registry (dict[str, MCPServerSettings]): Loaded server configurations. """ - registry: Dict[str, MCPServerSettings] = {} + registry: dict[str, MCPServerSettings] = {} def __init__( self, @@ -47,12 +46,12 @@ def __init__( ## TODO-- leaving this here to support more file formats to add servers def load_registry_from_file( self, config_path: str | None = None - ) -> Dict[str, MCPServerSettings]: + ) -> dict[str, MCPServerSettings]: """ Load the YAML configuration file and validate it. Returns: - Dict[str, MCPServerSettings]: A dictionary of server configurations. + dict[str, MCPServerSettings]: A dictionary of server configurations. Raises: ValueError: If the configuration is invalid. 
diff --git a/src/fast_agent/skills/registry.py b/src/fast_agent/skills/registry.py index c846bcb53..5eae5db66 100644 --- a/src/fast_agent/skills/registry.py +++ b/src/fast_agent/skills/registry.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, replace from pathlib import Path -from typing import List, Sequence +from typing import Sequence import frontmatter @@ -33,7 +33,7 @@ def __init__( self._base_dir = base_dir or Path.cwd() self._directory: Path | None = None self._override_failed: bool = False - self._errors: List[dict[str, str]] = [] + self._errors: list[dict[str, str]] = [] if override_directory: resolved = self._resolve_directory(override_directory) if resolved and resolved.exists() and resolved.is_dir(): @@ -55,14 +55,14 @@ def directory(self) -> Path | None: def override_failed(self) -> bool: return self._override_failed - def load_manifests(self) -> List[SkillManifest]: + def load_manifests(self) -> list[SkillManifest]: self._errors = [] if not self._directory: return [] manifests = self._load_directory(self._directory, self._errors) # Recompute relative paths to be from base_dir (workspace root) instead of skills directory - adjusted_manifests: List[SkillManifest] = [] + adjusted_manifests: list[SkillManifest] = [] for manifest in manifests: try: relative_path = manifest.path.relative_to(self._base_dir) @@ -74,12 +74,12 @@ def load_manifests(self) -> List[SkillManifest]: return adjusted_manifests - def load_manifests_with_errors(self) -> tuple[List[SkillManifest], List[dict[str, str]]]: + def load_manifests_with_errors(self) -> tuple[list[SkillManifest], list[dict[str, str]]]: manifests = self.load_manifests() return manifests, list(self._errors) @property - def errors(self) -> List[dict[str, str]]: + def errors(self) -> list[dict[str, str]]: return list(self._errors) def _find_default_directory(self) -> Path | None: @@ -95,7 +95,7 @@ def _resolve_directory(self, directory: Path) -> Path: return (self._base_dir / directory).resolve() @classmethod - def load_directory(cls, directory: Path) -> List[SkillManifest]: + def load_directory(cls, directory: Path) -> list[SkillManifest]: if not directory.exists() or not directory.is_dir(): logger.debug( "Skills directory not found", @@ -107,8 +107,8 @@ def load_directory(cls, directory: Path) -> List[SkillManifest]: @classmethod def load_directory_with_errors( cls, directory: Path - ) -> tuple[List[SkillManifest], List[dict[str, str]]]: - errors: List[dict[str, str]] = [] + ) -> tuple[list[SkillManifest], list[dict[str, str]]]: + errors: list[dict[str, str]] = [] manifests = cls._load_directory(directory, errors) return manifests, errors @@ -116,9 +116,9 @@ def load_directory_with_errors( def _load_directory( cls, directory: Path, - errors: List[dict[str, str]] | None = None, - ) -> List[SkillManifest]: - manifests: List[SkillManifest] = [] + errors: list[dict[str, str]] | None = None, + ) -> list[SkillManifest]: + manifests: list[SkillManifest] = [] for entry in sorted(directory.iterdir()): if not entry.is_dir(): continue @@ -201,7 +201,7 @@ def format_skills_for_prompt(manifests: Sequence[SkillManifest]) -> str: "Paths in skill documentation are relative to the skill's directory, not the workspace root. Use the full path from workspace root when executing." 
"Only use skills listed in below.\n\n" ) - formatted_parts: List[str] = [] + formatted_parts: list[str] = [] for manifest in manifests: description = (manifest.description or "").strip() @@ -210,7 +210,7 @@ def format_skills_for_prompt(manifests: Sequence[SkillManifest]) -> str: if relative_path is None and manifest.path: path_attr = f' path="{manifest.path}"' - block_lines: List[str] = [f''] + block_lines: list[str] = [f''] if description: block_lines.append(f"{description}") block_lines.append("") diff --git a/src/fast_agent/tools/elicitation.py b/src/fast_agent/tools/elicitation.py index 6fa7fde07..a6815117a 100644 --- a/src/fast_agent/tools/elicitation.py +++ b/src/fast_agent/tools/elicitation.py @@ -2,7 +2,7 @@ import json import uuid -from typing import Any, Awaitable, Callable, List, Literal, Optional, Union +from typing import Any, Awaitable, Callable, Literal, Union from mcp.server.fastmcp.tools import Tool as FastMCPTool from mcp.types import Tool as McpTool @@ -28,21 +28,21 @@ class OptionItem(BaseModel): value: Union[str, int, float, bool] - label: Optional[str] = None + label: str | None = None class FormField(BaseModel): name: str type: Literal["text", "textarea", "number", "checkbox", "radio"] - label: Optional[str] = None - help: Optional[str] = None - default: Optional[Union[str, int, float, bool]] = None - required: Optional[bool] = None + label: str | None = None + help: str | None = None + default: Union[str, int, float, bool] | None = None + required: bool | None = None # number constraints - min: Optional[float] = None - max: Optional[float] = None + min: float | None = None + max: float | None = None # select options (for radio) - options: Optional[List[OptionItem]] = None + options: list[OptionItem] | None = None class HumanFormArgs(BaseModel): @@ -51,10 +51,10 @@ class HumanFormArgs(BaseModel): Preferred shape for LLMs. 
""" - title: Optional[str] = None - description: Optional[str] = None - message: Optional[str] = None - fields: List[FormField] = Field(default_factory=list, max_length=7) + title: str | None = None + description: str | None = None + message: str | None = None + fields: list[FormField] = Field(default_factory=list, max_length=7) # ----------------------- @@ -136,7 +136,7 @@ def _resolve_refs(fragment: Any, root: dict[str, Any]) -> Any: # Elicitation input callback registry # ----------------------- -ElicitationCallback = Callable[[dict, Optional[str], Optional[str], Optional[dict]], Awaitable[str]] +ElicitationCallback = Callable[[dict, str | None, str | None, dict | None], Awaitable[str]] _elicitation_input_callback: ElicitationCallback | None = None @@ -344,10 +344,10 @@ def parse_schema_string(val: str) -> dict | None: def get_elicitation_fastmcp_tool() -> FastMCPTool: async def elicit( - title: Optional[str] = None, - description: Optional[str] = None, - message: Optional[str] = None, - fields: List[FormField] = Field(default_factory=list, max_length=7), + title: str | None = None, + description: str | None = None, + message: str | None = None, + fields: list[FormField] = Field(default_factory=list, max_length=7), ) -> str: args = { "title": title, diff --git a/src/fast_agent/tools/shell_runtime.py b/src/fast_agent/tools/shell_runtime.py index 2a744c542..20dc12321 100644 --- a/src/fast_agent/tools/shell_runtime.py +++ b/src/fast_agent/tools/shell_runtime.py @@ -8,7 +8,7 @@ import subprocess import time from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any from mcp.types import CallToolResult, TextContent, Tool @@ -75,7 +75,7 @@ def working_directory(self) -> Path: return self._skills_directory return Path.cwd() - def runtime_info(self) -> Dict[str, str | None]: + def runtime_info(self) -> dict[str, str | None]: """Best-effort detection of the shell runtime used for local execution. 
Uses modern Python APIs (platform.system(), shutil.which()) to detect @@ -108,7 +108,7 @@ def runtime_info(self) -> Dict[str, str | None]: # Fallback to generic sh return {"name": "sh", "path": None} - def metadata(self, command: Optional[str]) -> Dict[str, Any]: + def metadata(self, command: str | None) -> dict[str, Any]: """Build metadata for display when the shell tool is invoked.""" info = self.runtime_info() working_dir = self.working_directory() @@ -130,7 +130,7 @@ def metadata(self, command: Optional[str]) -> Dict[str, Any]: "returns_exit_code": True, } - async def execute(self, arguments: Dict[str, Any] | None = None) -> CallToolResult: + async def execute(self, arguments: dict[str, Any] | None = None) -> CallToolResult: """Execute a shell command and stream output to the console with timeout detection.""" command_value = (arguments or {}).get("command") if arguments else None if not isinstance(command_value, str) or not command_value.strip(): @@ -199,7 +199,7 @@ async def execute(self, arguments: Dict[str, Any] | None = None) -> CallToolResu watchdog_task = None async def stream_output( - stream, style: Optional[str], is_stderr: bool = False + stream, style: str | None, is_stderr: bool = False ) -> None: if not stream: return diff --git a/src/fast_agent/types/conversation_summary.py b/src/fast_agent/types/conversation_summary.py index c5d05d1d7..5085c9741 100644 --- a/src/fast_agent/types/conversation_summary.py +++ b/src/fast_agent/types/conversation_summary.py @@ -7,7 +7,6 @@ import json from collections import Counter -from typing import Dict, List from pydantic import BaseModel, computed_field @@ -48,7 +47,7 @@ class ConversationSummary(BaseModel): All computed properties are included in .model_dump() for easy serialization. """ - messages: List[PromptMessageExtended] + messages: list[PromptMessageExtended] @computed_field # type: ignore[prop-decorator] @property @@ -108,13 +107,13 @@ def tool_error_rate(self) -> float: @computed_field # type: ignore[prop-decorator] @property - def tool_call_map(self) -> Dict[str, int]: + def tool_call_map(self) -> dict[str, int]: """ Mapping of tool names to the number of times they were called. Example: {"fetch_weather": 3, "calculate": 1} """ - tool_names: List[str] = [] + tool_names: list[str] = [] for msg in self.messages: if msg.tool_calls: tool_names.extend( @@ -124,7 +123,7 @@ def tool_call_map(self) -> Dict[str, int]: @computed_field # type: ignore[prop-decorator] @property - def tool_error_map(self) -> Dict[str, int]: + def tool_error_map(self) -> dict[str, int]: """ Mapping of tool names to the number of errors they produced. @@ -134,14 +133,14 @@ def tool_error_map(self) -> Dict[str, int]: finding corresponding CallToolRequest entries in assistant messages. """ # First, build a map from tool_id -> tool_name by scanning tool_calls - tool_id_to_name: Dict[str, str] = {} + tool_id_to_name: dict[str, str] = {} for msg in self.messages: if msg.tool_calls: for tool_id, call in msg.tool_calls.items(): tool_id_to_name[tool_id] = call.params.name # Then, count errors by tool name - error_names: List[str] = [] + error_names: list[str] = [] for msg in self.messages: if msg.tool_results: for tool_id, result in msg.tool_results.items(): @@ -194,7 +193,7 @@ def total_elapsed_time_ms(self) -> float: @computed_field # type: ignore[prop-decorator] @property - def assistant_message_timings(self) -> List[Dict[str, float]]: + def assistant_message_timings(self) -> list[dict[str, float]]: """ List of timing data for each assistant message. 
diff --git a/src/fast_agent/types/message_search.py b/src/fast_agent/types/message_search.py index 3ca39ccf0..63259d08a 100644 --- a/src/fast_agent/types/message_search.py +++ b/src/fast_agent/types/message_search.py @@ -16,7 +16,7 @@ """ import re -from typing import List, Literal, Tuple +from typing import Literal from fast_agent.mcp.helpers.content_helpers import get_text from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -25,10 +25,10 @@ def search_messages( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], pattern: str | re.Pattern, scope: SearchScope = "all", -) -> List[PromptMessageExtended]: +) -> list[PromptMessageExtended]: """ Find messages containing content that matches a pattern. @@ -61,10 +61,10 @@ def search_messages( def find_matches( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], pattern: str | re.Pattern, scope: SearchScope = "all", -) -> List[Tuple[PromptMessageExtended, re.Match]]: +) -> list[tuple[PromptMessageExtended, re.Match]]: """ Find all pattern matches in messages, returning match objects. @@ -103,7 +103,7 @@ def find_matches( def extract_first( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], pattern: str | re.Pattern, scope: SearchScope = "all", group: int = 0, @@ -142,7 +142,7 @@ def extract_first( def extract_last( - messages: List[PromptMessageExtended], + messages: list[PromptMessageExtended], pattern: str | re.Pattern, scope: SearchScope = "all", group: int = 0, @@ -198,7 +198,7 @@ def _find_in_message( msg: PromptMessageExtended, pattern: re.Pattern, scope: SearchScope, -) -> List[re.Match]: +) -> list[re.Match]: """Find all matches of pattern in a message.""" texts = _extract_searchable_text(msg, scope) matches = [] @@ -211,7 +211,7 @@ def _find_in_message( def _extract_searchable_text( msg: PromptMessageExtended, scope: SearchScope, -) -> List[str]: +) -> list[str]: """Extract text from message based on scope.""" texts = [] diff --git a/src/fast_agent/ui/console_display.py b/src/fast_agent/ui/console_display.py index 76710f98f..370f68896 100644 --- a/src/fast_agent/ui/console_display.py +++ b/src/fast_agent/ui/console_display.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from json import JSONDecodeError -from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Iterator, Mapping, Union from mcp.types import CallToolResult from rich.markdown import Markdown @@ -109,7 +109,7 @@ def display_message( message_type: MessageType, name: str | None = None, right_info: str = "", - bottom_metadata: List[str] | None = None, + bottom_metadata: list[str] | None = None, highlight_index: int | None = None, max_item_length: int | None = None, is_error: bool = False, @@ -174,7 +174,7 @@ def _display_content( content: Any, truncate: bool = True, is_error: bool = False, - message_type: Optional[MessageType] = None, + message_type: MessageType | None = None, check_markdown_markers: bool = False, ) -> None: """ @@ -370,7 +370,7 @@ def _display_content( else: console.console.print(pretty_obj, markup=self._markup) - def _shorten_items(self, items: List[str], max_length: int) -> List[str]: + def _shorten_items(self, items: list[str], max_length: int) -> list[str]: """ Shorten items to max_length with ellipsis if needed. 
@@ -387,7 +387,7 @@ def _render_bottom_metadata( self, *, message_type: MessageType, - bottom_metadata: List[str] | None, + bottom_metadata: list[str] | None, highlight_index: int | None, max_item_length: int | None, ) -> None: @@ -437,7 +437,7 @@ def _render_bottom_metadata( def _format_bottom_metadata( self, - items: List[str], + items: list[str], highlight_index: int | None, highlight_color: str, max_width: int | None = None, @@ -609,12 +609,12 @@ def _extract_reasoning_content(self, message: "PromptMessageExtended") -> Text | async def show_assistant_message( self, message_text: Union[str, Text, "PromptMessageExtended"], - bottom_items: List[str] | None = None, + bottom_items: list[str] | None = None, highlight_index: int | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, - additional_message: Optional[Text] = None, + additional_message: Text | None = None, ) -> None: """Display an assistant message in a formatted panel. @@ -674,7 +674,7 @@ async def show_assistant_message( def streaming_assistant_message( self, *, - bottom_items: List[str] | None = None, + bottom_items: list[str] | None = None, highlight_index: int | None = None, max_item_length: int | None = None, name: str | None = None, @@ -718,7 +718,7 @@ def streaming_assistant_message( finally: handle.close() - def _display_mermaid_diagrams(self, diagrams: List[MermaidDiagram]) -> None: + def _display_mermaid_diagrams(self, diagrams: list[MermaidDiagram]) -> None: """Display mermaid diagram links.""" diagram_content = Text() # Add bullet at the beginning @@ -746,7 +746,7 @@ def _display_mermaid_diagrams(self, diagrams: List[MermaidDiagram]) -> None: console.console.print() console.console.print(diagram_content, markup=self._markup) - async def show_mcp_ui_links(self, links: List[UILink]) -> None: + async def show_mcp_ui_links(self, links: list[UILink]) -> None: """Display MCP-UI links beneath the chat like mermaid links.""" if self.config and not self.config.logger.show_chat: return @@ -824,12 +824,12 @@ def show_system_message( async def show_prompt_loaded( self, prompt_name: str, - description: Optional[str] = None, + description: str | None = None, message_count: int = 0, - agent_name: Optional[str] = None, - server_list: List[str] | None = None, + agent_name: str | None = None, + server_list: list[str] | None = None, highlight_server: str | None = None, - arguments: Optional[dict[str, str]] = None, + arguments: dict[str, str] | None = None, ) -> None: """ Display information about a loaded prompt template. 
diff --git a/src/fast_agent/ui/elicitation_form.py b/src/fast_agent/ui/elicitation_form.py index cef1d4292..da69b8589 100644 --- a/src/fast_agent/ui/elicitation_form.py +++ b/src/fast_agent/ui/elicitation_form.py @@ -2,7 +2,7 @@ import re from datetime import date, datetime -from typing import Any, Dict, Optional +from typing import Any from mcp.types import ElicitRequestedSchema from prompt_toolkit import Application @@ -33,7 +33,7 @@ class SimpleNumberValidator(Validator): """Simple number validator with real-time feedback.""" def __init__( - self, field_type: str, minimum: Optional[float] = None, maximum: Optional[float] = None + self, field_type: str, minimum: float | None = None, maximum: float | None = None ): self.field_type = field_type self.minimum = minimum @@ -69,9 +69,9 @@ class SimpleStringValidator(Validator): def __init__( self, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - pattern: Optional[str] = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, ): self.min_length = min_length self.max_length = max_length @@ -485,7 +485,7 @@ def set_initial_focus(): self.app.invalidate() # Ensure layout is built set_initial_focus() - def _extract_string_constraints(self, field_def: Dict[str, Any]) -> Dict[str, Any]: + def _extract_string_constraints(self, field_def: dict[str, Any]) -> dict[str, Any]: """Extract string constraints from field definition, handling anyOf schemas.""" constraints = {} @@ -511,7 +511,7 @@ def _extract_string_constraints(self, field_def: Dict[str, Any]) -> Dict[str, An return constraints - def _create_field(self, field_name: str, field_def: Dict[str, Any]): + def _create_field(self, field_name: str, field_def: dict[str, Any]): """Create a field widget.""" field_type = field_def.get("type", "string") @@ -695,7 +695,7 @@ def get_dynamic_height(): return HSplit([label, Frame(text_input)]) - def _validate_form(self) -> tuple[bool, Optional[str]]: + def _validate_form(self) -> tuple[bool, str | None]: """Validate the entire form.""" # First, check all fields for validation errors from their validators @@ -728,9 +728,9 @@ def _validate_form(self) -> tuple[bool, Optional[str]]: return True, None - def _get_form_data(self) -> Dict[str, Any]: + def _get_form_data(self) -> dict[str, Any]: """Extract data from form fields.""" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} for field_name, field_def in self.properties.items(): widget = self.field_widgets.get(field_name) @@ -831,7 +831,7 @@ def _clear_status_bar(self): self.app.layout.container = new_layout self.app.invalidate() - async def run_async(self) -> tuple[str, Optional[Dict[str, Any]]]: + async def run_async(self) -> tuple[str, dict[str, Any] | None]: """Run the form and return result.""" try: await self.app.run_async() @@ -844,7 +844,7 @@ async def run_async(self) -> tuple[str, Optional[Dict[str, Any]]]: async def show_simple_elicitation_form( schema: ElicitRequestedSchema, message: str, agent_name: str, server_name: str -) -> tuple[str, Optional[Dict[str, Any]]]: +) -> tuple[str, dict[str, Any] | None]: """Show the simplified elicitation form.""" form = ElicitationForm(schema, message, agent_name, server_name) return await form.run_async() diff --git a/src/fast_agent/ui/enhanced_prompt.py b/src/fast_agent/ui/enhanced_prompt.py index c0e984fcf..de7a6ee23 100644 --- a/src/fast_agent/ui/enhanced_prompt.py +++ b/src/fast_agent/ui/enhanced_prompt.py @@ -10,7 +10,7 @@ import tempfile from importlib.metadata import version from pathlib 
import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any from prompt_toolkit import PromptSession from prompt_toolkit.completion import Completer, Completion, WordCompleter @@ -213,7 +213,7 @@ async def _display_agent_info_helper(agent_name: str, agent_provider: "AgentApp async def _display_all_agents_with_hierarchy( - available_agents: List[str], agent_provider: "AgentApp | None" + available_agents: list[str], agent_provider: "AgentApp | None" ) -> None: """Display all agents with tree structure for workflow agents.""" # Track which agents are children to avoid displaying them twice @@ -352,8 +352,8 @@ class AgentCompleter(Completer): def __init__( self, - agents: List[str], - commands: List[str] = None, + agents: list[str], + commands: list[str] = None, agent_types: dict = None, is_human_input: bool = False, ) -> None: @@ -609,7 +609,7 @@ async def get_enhanced_input( show_default: bool = False, show_stop_hint: bool = False, multiline: bool = False, - available_agent_names: List[str] = None, + available_agent_names: list[str] = None, agent_types: dict[str, AgentType] = None, is_human_input: bool = False, toolbar_color: str = "ansiblue", @@ -1097,11 +1097,11 @@ def pre_process_input(text): async def get_selection_input( prompt_text: str, - options: List[str] = None, + options: list[str] = None, default: str = None, allow_cancel: bool = True, complete_options: bool = True, -) -> Optional[str]: +) -> str | None: """ Display a selection prompt and return the user's selection. @@ -1146,7 +1146,7 @@ async def get_argument_input( arg_name: str, description: str = None, required: bool = True, -) -> Optional[str]: +) -> str | None: """ Prompt for an argument value with formatting and help text. @@ -1194,7 +1194,7 @@ async def get_argument_input( async def handle_special_commands( command: Any, agent_app: "AgentApp | None" = None -) -> bool | Dict[str, Any]: +) -> bool | dict[str, Any]: """ Handle special input commands. diff --git a/src/fast_agent/ui/interactive_prompt.py b/src/fast_agent/ui/interactive_prompt.py index 11dfa2bb6..42faa23c6 100644 --- a/src/fast_agent/ui/interactive_prompt.py +++ b/src/fast_agent/ui/interactive_prompt.py @@ -15,7 +15,7 @@ """ from pathlib import Path -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Union, cast from fast_agent.constants import CONTROL_MESSAGE_SAVE_HISTORY @@ -46,7 +46,7 @@ SendFunc = Callable[[Union[str, PromptMessage, PromptMessageExtended], str], Awaitable[str]] # Type alias for the agent getter function -AgentGetter = Callable[[str], Optional[object]] +AgentGetter = Callable[[str], object | None] class InteractivePrompt: @@ -55,20 +55,20 @@ class InteractivePrompt: This is extracted from the original AgentApp implementation to support DirectAgentApp. """ - def __init__(self, agent_types: Optional[Dict[str, AgentType]] = None) -> None: + def __init__(self, agent_types: dict[str, AgentType] | None = None) -> None: """ Initialize the interactive prompt. 
Args: agent_types: Dictionary mapping agent names to their types for display """ - self.agent_types: Dict[str, AgentType] = agent_types or {} + self.agent_types: dict[str, AgentType] = agent_types or {} async def prompt_loop( self, send_func: SendFunc, default_agent: str, - available_agents: List[str], + available_agents: list[str], prompt_provider: "AgentApp", default: str = "", ) -> str: @@ -118,7 +118,7 @@ async def prompt_loop( # Check if we should switch agents if isinstance(command_result, dict): - command_dict: Dict[str, Any] = command_result + command_dict: dict[str, Any] = command_result if "switch_agent" in command_dict: new_agent = command_dict["switch_agent"] if new_agent in available_agents_set: @@ -380,7 +380,7 @@ def _create_combined_separator_status( console.print(combined) rich_print() - async def _get_all_prompts(self, prompt_provider: "AgentApp", agent_name: Optional[str] = None): + async def _get_all_prompts(self, prompt_provider: "AgentApp", agent_name: str | None = None): """ Get a list of all available prompts. @@ -566,8 +566,8 @@ async def _select_prompt( self, prompt_provider: "AgentApp", agent_name: str, - requested_name: Optional[str] = None, - send_func: Optional[SendFunc] = None, + requested_name: str | None = None, + send_func: SendFunc | None = None, ) -> None: """ Select and apply a prompt. @@ -597,7 +597,7 @@ async def _select_prompt( continue # Extract prompts - prompts: List[Prompt] = [] + prompts: list[Prompt] = [] if hasattr(prompts_info, "prompts"): prompts = prompts_info.prompts elif isinstance(prompts_info, list): diff --git a/src/fast_agent/ui/markdown_truncator.py b/src/fast_agent/ui/markdown_truncator.py index 2f4658653..1c7e61212 100644 --- a/src/fast_agent/ui/markdown_truncator.py +++ b/src/fast_agent/ui/markdown_truncator.py @@ -30,7 +30,7 @@ """ from dataclasses import dataclass -from typing import Iterable, List, Optional +from typing import Iterable from markdown_it import MarkdownIt from markdown_it.token import Token @@ -71,7 +71,7 @@ class TableInfo: thead_end_pos: int tbody_start_pos: int tbody_end_pos: int - header_lines: List[str] # Header + separator rows + header_lines: list[str] # Header + separator rows class MarkdownTruncator: @@ -92,11 +92,11 @@ def __init__(self, target_height_ratio: float = 0.8): self._last_terminal_height: int | None = None # Markdown parse cache self._cache_source: str | None = None - self._cache_tokens: List[Token] | None = None - self._cache_lines: List[str] | None = None - self._cache_safe_points: List[TruncationPoint] | None = None - self._cache_code_blocks: List[CodeBlockInfo] | None = None - self._cache_tables: List[TableInfo] | None = None + self._cache_tokens: list[Token] | None = None + self._cache_lines: list[str] | None = None + self._cache_safe_points: list[TruncationPoint] | None = None + self._cache_code_blocks: list[CodeBlockInfo] | None = None + self._cache_tables: list[TableInfo] | None = None def truncate( self, @@ -305,7 +305,7 @@ def _ensure_parse_cache(self, text: str) -> None: self._cache_code_blocks = None self._cache_tables = None - def _find_safe_truncation_points(self, text: str) -> List[TruncationPoint]: + def _find_safe_truncation_points(self, text: str) -> list[TruncationPoint]: """Find safe positions to truncate at (block boundaries). 
Args: @@ -321,7 +321,7 @@ def _find_safe_truncation_points(self, text: str) -> List[TruncationPoint]: assert self._cache_tokens is not None assert self._cache_lines is not None - safe_points: List[TruncationPoint] = [] + safe_points: list[TruncationPoint] = [] tokens = self._cache_tokens lines = self._cache_lines @@ -347,7 +347,7 @@ def _find_safe_truncation_points(self, text: str) -> List[TruncationPoint]: self._cache_safe_points = safe_points return list(safe_points) - def _get_code_block_info(self, text: str) -> List[CodeBlockInfo]: + def _get_code_block_info(self, text: str) -> list[CodeBlockInfo]: """Extract code block positions and metadata using markdown-it. Uses same technique as prepare_markdown_content in markdown_helpers.py: @@ -368,7 +368,7 @@ def _get_code_block_info(self, text: str) -> List[CodeBlockInfo]: tokens = self._cache_tokens lines = self._cache_lines - code_blocks: List[CodeBlockInfo] = [] + code_blocks: list[CodeBlockInfo] = [] for token in self._flatten_tokens(tokens): if token.type in ("fence", "code_block") and token.map: @@ -417,7 +417,7 @@ def _build_code_block_prefix(self, block: CodeBlockInfo) -> str | None: return None - def _get_table_info(self, text: str) -> List[TableInfo]: + def _get_table_info(self, text: str) -> list[TableInfo]: """Extract table positions and metadata using markdown-it. Uses same technique as _get_code_block_info: parse once with markdown-it, @@ -438,7 +438,7 @@ def _get_table_info(self, text: str) -> List[TableInfo]: tokens = self._cache_tokens lines = self._cache_lines - tables: List[TableInfo] = [] + tables: list[TableInfo] = [] for i, token in enumerate(tokens): if token.type == "table_open" and token.map: @@ -511,12 +511,12 @@ def _get_table_info(self, text: str) -> List[TableInfo]: def _find_best_truncation_point( self, text: str, - safe_points: List[TruncationPoint], + safe_points: list[TruncationPoint], target_height: int, console: Console, code_theme: str, keep_beginning: bool = False, - ) -> Optional[TruncationPoint]: + ) -> TruncationPoint | None: """Find the truncation point that gets closest to target height. 
Args: diff --git a/src/fast_agent/ui/mcp_ui_utils.py b/src/fast_agent/ui/mcp_ui_utils.py index 946c250a1..bfe8ac194 100644 --- a/src/fast_agent/ui/mcp_ui_utils.py +++ b/src/fast_agent/ui/mcp_ui_utils.py @@ -7,7 +7,7 @@ import webbrowser from dataclasses import dataclass from pathlib import Path -from typing import Iterable, List, Optional +from typing import Iterable from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents @@ -30,7 +30,7 @@ class UILink: title: str file_path: str # absolute path to local html file - web_url: Optional[str] = None # Preferable clickable link (http(s) or data URL) + web_url: str | None = None # Preferable clickable link (http(s) or data URL) def _safe_filename(name: str) -> str: @@ -69,7 +69,7 @@ def _extract_title(uri: str | None) -> str: return "UI" -def _decode_text_or_blob(resource) -> Optional[str]: +def _decode_text_or_blob(resource) -> str | None: """Return string content from TextResourceContents or BlobResourceContents.""" if isinstance(resource, TextResourceContents): return resource.text or "" @@ -81,7 +81,7 @@ def _decode_text_or_blob(resource) -> Optional[str]: return None -def _first_https_url_from_uri_list(text: str) -> Optional[str]: +def _first_https_url_from_uri_list(text: str) -> str | None: for line in text.splitlines(): line = line.strip() if not line or line.startswith("#"): @@ -126,7 +126,7 @@ def _write_html_file(name_hint: str, html: str) -> str: return str(out_path.resolve()) -def ui_links_from_channel(resources: Iterable[EmbeddedResource]) -> List[UILink]: +def ui_links_from_channel(resources: Iterable[EmbeddedResource]) -> list[UILink]: """ Build local HTML files for a list of MCP-UI EmbeddedResources and return clickable links. @@ -135,7 +135,7 @@ def ui_links_from_channel(resources: Iterable[EmbeddedResource]) -> List[UILink] - text/uri-list: expects text or blob of a single URL (first valid URL is used) - application/vnd.mcp-ui.remote-dom* : currently unsupported; generate a placeholder page """ - links: List[UILink] = [] + links: list[UILink] = [] for emb in resources: res = emb.resource uri = str(getattr(res, "uri", "")) if getattr(res, "uri", None) else None diff --git a/src/fast_agent/ui/mermaid_utils.py b/src/fast_agent/ui/mermaid_utils.py index 666f8123f..7e0271cfb 100644 --- a/src/fast_agent/ui/mermaid_utils.py +++ b/src/fast_agent/ui/mermaid_utils.py @@ -4,7 +4,6 @@ import re import zlib from dataclasses import dataclass -from typing import List, Optional # Mermaid chart viewer URL prefix MERMAID_VIEWER_URL = "https://www.mermaidchart.com/play#" @@ -16,12 +15,12 @@ class MermaidDiagram: """Represents a detected Mermaid diagram.""" content: str - title: Optional[str] = None + title: str | None = None start_pos: int = 0 end_pos: int = 0 -def extract_mermaid_diagrams(text: str) -> List[MermaidDiagram]: +def extract_mermaid_diagrams(text: str) -> list[MermaidDiagram]: """ Extract all Mermaid diagram blocks from text content. @@ -102,7 +101,7 @@ def create_mermaid_live_link(diagram_content: str) -> str: return f"{MERMAID_VIEWER_URL}pako:{encoded}" -def format_mermaid_links(diagrams: List[MermaidDiagram]) -> List[str]: +def format_mermaid_links(diagrams: list[MermaidDiagram]) -> list[str]: """ Format Mermaid diagrams as markdown links. 
diff --git a/src/fast_agent/ui/notification_tracker.py b/src/fast_agent/ui/notification_tracker.py index a03a4a6cb..1b39e2a7c 100644 --- a/src/fast_agent/ui/notification_tracker.py +++ b/src/fast_agent/ui/notification_tracker.py @@ -4,7 +4,6 @@ """ from datetime import datetime -from typing import Dict, List, Optional # Display metadata for toolbar summaries (singular, plural, compact label) _EVENT_ORDER = ("tool_update", "sampling", "elicitation") @@ -15,10 +14,10 @@ } # Active events currently in progress -active_events: Dict[str, Dict[str, str]] = {} +active_events: dict[str, dict[str, str]] = {} # Completed notifications history -notifications: List[Dict[str, str]] = [] +notifications: list[dict[str, str]] = [] def add_tool_update(server_name: str) -> None: @@ -115,7 +114,7 @@ def end_elicitation(server_name: str) -> None: pass -def get_active_status() -> Optional[Dict[str, str]]: +def get_active_status() -> dict[str, str] | None: """Get currently active operation, if any. Returns: @@ -139,14 +138,14 @@ def get_count() -> int: return len(notifications) -def get_latest() -> Dict[str, str] | None: +def get_latest() -> dict[str, str] | None: """Get the most recent completed notification.""" return notifications[-1] if notifications else None -def get_counts_by_type() -> Dict[str, int]: +def get_counts_by_type() -> dict[str, int]: """Aggregate completed notifications by event type.""" - counts: Dict[str, int] = {} + counts: dict[str, int] = {} for notification in notifications: event_type = notification['type'] counts[event_type] = counts.get(event_type, 0) + 1 @@ -154,7 +153,7 @@ def get_counts_by_type() -> Dict[str, int]: if not counts: return {} - ordered: Dict[str, int] = {} + ordered: dict[str, int] = {} for event_type in _EVENT_ORDER: if event_type in counts: ordered[event_type] = counts[event_type] diff --git a/src/fast_agent/ui/rich_progress.py b/src/fast_agent/ui/rich_progress.py index 90f9ae4a5..c180e40ff 100644 --- a/src/fast_agent/ui/rich_progress.py +++ b/src/fast_agent/ui/rich_progress.py @@ -2,7 +2,7 @@ import time from contextlib import contextmanager -from typing import Any, Optional +from typing import Any from rich.console import Console from rich.progress import Progress, SpinnerColumn, TaskID, TextColumn @@ -14,7 +14,7 @@ class RichProgressDisplay: """Rich-based display for progress events.""" - def __init__(self, console: Optional[Console] = None) -> None: + def __init__(self, console: Console | None = None) -> None: """Initialize the progress display.""" self.console = console or default_console self._taskmap: dict[str, TaskID] = {} diff --git a/src/fast_agent/ui/streaming_buffer.py b/src/fast_agent/ui/streaming_buffer.py index 59c4a1ae0..28b1f997a 100644 --- a/src/fast_agent/ui/streaming_buffer.py +++ b/src/fast_agent/ui/streaming_buffer.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from math import ceil -from typing import Generator, List, Optional +from typing import Generator from markdown_it import MarkdownIt from markdown_it.token import Token @@ -40,7 +40,7 @@ class Table: start_pos: int # Character position where table starts end_pos: int # Character position where table ends - header_lines: List[str] # Header row + separator (e.g., ["| A | B |", "|---|---|"]) + header_lines: list[str] # Header row + separator (e.g., ["| A | B |", "|---|---|"]) class StreamBuffer: @@ -56,7 +56,7 @@ class StreamBuffer: def __init__(self): """Initialize the stream buffer.""" - self._chunks: List[str] = [] + self._chunks: list[str] = [] self._parser = 
MarkdownIt().enable("table") def append(self, chunk: str) -> None: @@ -80,7 +80,7 @@ def get_display_text( self, terminal_height: int, target_ratio: float = 0.7, - terminal_width: Optional[int] = None, + terminal_width: int | None = None, ) -> str: """Get text for display, truncated to fit terminal. @@ -114,7 +114,7 @@ def _truncate_for_display( text: str, terminal_height: int, target_ratio: float, - terminal_width: Optional[int], + terminal_width: int | None, ) -> str: """Truncate text to fit display with context preservation. @@ -208,7 +208,7 @@ def _truncate_for_display( return truncated_text - def _find_code_blocks(self, text: str) -> List[CodeBlock]: + def _find_code_blocks(self, text: str) -> list[CodeBlock]: """Find all code blocks in text using markdown-it parser. Args: @@ -235,7 +235,7 @@ def _find_code_blocks(self, text: str) -> List[CodeBlock]: return blocks - def _find_tables(self, text: str) -> List[Table]: + def _find_tables(self, text: str) -> list[Table]: """Find all tables in text using markdown-it parser. Args: @@ -279,7 +279,7 @@ def _find_tables(self, text: str) -> List[Table]: return tables def _preserve_code_block_context( - self, original_text: str, truncated_text: str, truncation_pos: int, code_blocks: List[CodeBlock] + self, original_text: str, truncated_text: str, truncation_pos: int, code_blocks: list[CodeBlock] ) -> str: """Prepend code block opening fence if truncation removed it. @@ -310,7 +310,7 @@ def _preserve_code_block_context( return truncated_text def _preserve_table_context( - self, original_text: str, truncated_text: str, truncation_pos: int, tables: List[Table] + self, original_text: str, truncated_text: str, truncation_pos: int, tables: list[Table] ) -> str: """Prepend table header if truncation removed it. @@ -382,7 +382,7 @@ def _add_closing_fence_if_needed(self, text: str) -> str: return text - def _flatten_tokens(self, tokens: List[Token]) -> Generator[Token, None, None]: + def _flatten_tokens(self, tokens: list[Token]) -> Generator[Token, None, None]: """Flatten nested token tree. Args: @@ -399,7 +399,7 @@ def _flatten_tokens(self, tokens: List[Token]) -> Generator[Token, None, None]: else: yield token - def _estimate_display_counts(self, lines: List[str], terminal_width: int) -> List[int]: + def _estimate_display_counts(self, lines: list[str], terminal_width: int) -> list[int]: """Estimate how many terminal rows each logical line will occupy.""" return [ max(1, ceil(len(line) / terminal_width)) if line else 1 diff --git a/src/fast_agent/ui/usage_display.py b/src/fast_agent/ui/usage_display.py index f922a6328..c866b303b 100644 --- a/src/fast_agent/ui/usage_display.py +++ b/src/fast_agent/ui/usage_display.py @@ -3,13 +3,13 @@ Consolidates the usage display logic that was duplicated between fastagent.py and interactive_prompt.py. """ -from typing import Any, Dict, Optional +from typing import Any from rich.console import Console def display_usage_report( - agents: Dict[str, Any], show_if_progress_disabled: bool = False, subdued_colors: bool = False + agents: dict[str, Any], show_if_progress_disabled: bool = False, subdued_colors: bool = False ) -> None: """ Display a formatted table of token usage for all agents. @@ -179,8 +179,8 @@ def display_usage_report( def collect_agents_from_provider( - prompt_provider: Any, agent_name: Optional[str] = None -) -> Dict[str, Any]: + prompt_provider: Any, agent_name: str | None = None +) -> dict[str, Any]: """ Collect agents from a prompt provider for usage display. 
diff --git a/src/fast_agent/workflow_telemetry.py b/src/fast_agent/workflow_telemetry.py index 52a811872..534b61074 100644 --- a/src/fast_agent/workflow_telemetry.py +++ b/src/fast_agent/workflow_telemetry.py @@ -13,7 +13,7 @@ import asyncio from contextlib import AbstractAsyncContextManager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Protocol +from typing import TYPE_CHECKING, Any, Protocol from mcp.types import ContentBlock, TextContent @@ -117,7 +117,7 @@ class _ToolHandlerWorkflowStep(WorkflowTelemetry): server_name: str arguments: dict[str, Any] | None - _tool_call_id: Optional[str] = None + _tool_call_id: str | None = None _finished: bool = False _lock: asyncio.Lock = asyncio.Lock() diff --git a/tests/e2e/bedrock/test_dynamic_capabilities.py b/tests/e2e/bedrock/test_dynamic_capabilities.py index db6c1a388..a74eda916 100644 --- a/tests/e2e/bedrock/test_dynamic_capabilities.py +++ b/tests/e2e/bedrock/test_dynamic_capabilities.py @@ -1,5 +1,4 @@ import sys -from typing import List import pytest @@ -16,7 +15,7 @@ def debug_cache_at_end(): BedrockLLM.debug_cache() -def _bedrock_models_for_capability_tests() -> List[str]: +def _bedrock_models_for_capability_tests() -> list[str]: """Return Bedrock models if AWS is configured, otherwise return empty list.""" try: return all_bedrock_models(prefix="") diff --git a/tests/e2e/bedrock/test_e2e_smoke_bedrock.py b/tests/e2e/bedrock/test_e2e_smoke_bedrock.py index 6c2758fd6..3a5285297 100644 --- a/tests/e2e/bedrock/test_e2e_smoke_bedrock.py +++ b/tests/e2e/bedrock/test_e2e_smoke_bedrock.py @@ -1,5 +1,5 @@ import sys -from typing import Annotated, List +from typing import Annotated import pytest from pydantic import BaseModel, Field @@ -19,7 +19,7 @@ def debug_cache_at_end(): BedrockLLM.debug_cache() -def _bedrock_models_for_smoke() -> List[str]: +def _bedrock_models_for_smoke() -> list[str]: """Return Bedrock models if AWS is configured, otherwise return empty list.""" try: return all_bedrock_models(prefix="") @@ -204,7 +204,7 @@ async def weather_forecast(): await weather_forecast() -def _bedrock_models_for_structured() -> List[str]: +def _bedrock_models_for_structured() -> list[str]: """Return Bedrock models suitable for structured-output tests. Prefer Nova and Claude 3.x families. 
diff --git a/tests/e2e/smoke/base/test_e2e_smoke.py b/tests/e2e/smoke/base/test_e2e_smoke.py index 86c69a15e..e9ccee881 100644 --- a/tests/e2e/smoke/base/test_e2e_smoke.py +++ b/tests/e2e/smoke/base/test_e2e_smoke.py @@ -1,6 +1,6 @@ import os from enum import Enum -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING import pytest from pydantic import BaseModel, Field @@ -176,7 +176,7 @@ class WeatherForecast(BaseModel): location: str = Field(..., description="City and country") unit: TemperatureUnit = Field(..., description="Temperature unit") - forecast: List[DailyForecast] = Field(..., description="Daily forecasts") + forecast: list[DailyForecast] = Field(..., description="Daily forecasts") summary: str = Field(..., description="Brief summary of the overall forecast") diff --git a/tests/integration/acp/test_acp_slash_commands.py b/tests/integration/acp/test_acp_slash_commands.py index 1bd85e124..b97c2af66 100644 --- a/tests/integration/acp/test_acp_slash_commands.py +++ b/tests/integration/acp/test_acp_slash_commands.py @@ -5,7 +5,7 @@ import sys from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, cast import pytest from mcp.types import TextContent @@ -30,7 +30,7 @@ @dataclass class StubAgent: - message_history: List[Any] = field(default_factory=list) + message_history: list[Any] = field(default_factory=list) _llm: Any = None cleared: bool = False popped: bool = False @@ -48,7 +48,7 @@ def pop_last_message(self): @dataclass class StubAgentInstance: - agents: Dict[str, Any] = field(default_factory=dict) + agents: dict[str, Any] = field(default_factory=dict) def _handler( @@ -220,9 +220,9 @@ async def test_slash_command_save_conversation() -> None: class RecordingHistoryExporter: def __init__(self, default_name: str = "24_01_01_12_00-conversation.json") -> None: self.default_name = default_name - self.calls: List[tuple[Any, Optional[str]]] = [] + self.calls: list[tuple[Any, str | None]] = [] - async def save(self, agent, filename: Optional[str] = None) -> str: + async def save(self, agent, filename: str | None = None) -> str: self.calls.append((agent, filename)) return filename or self.default_name diff --git a/tests/integration/elicitation/elicitation_test_server_advanced.py b/tests/integration/elicitation/elicitation_test_server_advanced.py index 85eb98291..99a47801f 100644 --- a/tests/integration/elicitation/elicitation_test_server_advanced.py +++ b/tests/integration/elicitation/elicitation_test_server_advanced.py @@ -4,7 +4,6 @@ import logging import sys -from typing import Optional from mcp import ( ReadResourceResult, @@ -129,7 +128,7 @@ class UserProfile(BaseModel): ], }, ) - email: Optional[str] = Field( + email: str | None = Field( None, description="Your email address (optional)", json_schema_extra={"format": "email"} ) subscribe_newsletter: bool = Field(False, description="Subscribe to our newsletter?") @@ -210,7 +209,7 @@ class Feedback(BaseModel): overall_rating: int = Field(description="Overall rating (1-5)", ge=1, le=5) ease_of_use: float = Field(description="Ease of use (0.0-10.0)", ge=0.0, le=10.0) would_recommend: bool = Field(description="Would you recommend to others?") - comments: Optional[str] = Field(None, description="Additional comments", max_length=500) + comments: str | None = Field(None, description="Additional comments", max_length=500) result = await mcp.get_context().elicit("We'd love your feedback!", schema=Feedback) 
diff --git a/tests/integration/elicitation/test_elicitation_handler.py b/tests/integration/elicitation/test_elicitation_handler.py index 861ffeb74..aa79739ab 100644 --- a/tests/integration/elicitation/test_elicitation_handler.py +++ b/tests/integration/elicitation/test_elicitation_handler.py @@ -5,7 +5,7 @@ to verify custom handler functionality. """ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from mcp.shared.context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult @@ -28,7 +28,7 @@ async def custom_elicitation_handler( if params.requestedSchema: # Generate test data based on the schema for round-trip verification properties = params.requestedSchema.get("properties", {}) - content: Dict[str, Any] = {} + content: dict[str, Any] = {} # Provide test values for each field for field_name, field_def in properties.items(): diff --git a/tests/integration/elicitation/test_elicitation_integration.py b/tests/integration/elicitation/test_elicitation_integration.py index 3a4e68ccc..b8f630273 100644 --- a/tests/integration/elicitation/test_elicitation_integration.py +++ b/tests/integration/elicitation/test_elicitation_integration.py @@ -7,7 +7,7 @@ 3. Elicitation capabilities are properly advertised to servers """ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any import pytest from mcp.shared.context import RequestContext @@ -31,7 +31,7 @@ async def custom_test_elicitation_handler( if params.requestedSchema: # Generate test data based on the schema for round-trip verification properties = params.requestedSchema.get("properties", {}) - content: Dict[str, Any] = {} + content: dict[str, Any] = {} # Provide test values for each field for field_name, field_def in properties.items(): diff --git a/tests/integration/elicitation/testing_handlers.py b/tests/integration/elicitation/testing_handlers.py index 394de273f..22877d95e 100644 --- a/tests/integration/elicitation/testing_handlers.py +++ b/tests/integration/elicitation/testing_handlers.py @@ -5,7 +5,7 @@ where you need predictable, automated responses. 
""" -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from mcp.shared.context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult @@ -55,7 +55,7 @@ async def auto_cancel_test_handler( return ElicitResult(action="cancel") -def _generate_test_response(schema: Dict[str, Any]) -> Dict[str, Any]: +def _generate_test_response(schema: dict[str, Any]) -> dict[str, Any]: """Generate realistic test data based on JSON schema.""" if not schema or "properties" not in schema: return {"response": "default-test"} diff --git a/tests/integration/prompt-server/test_prompt_server_integration.py b/tests/integration/prompt-server/test_prompt_server_integration.py index 4606c33f8..f04798436 100644 --- a/tests/integration/prompt-server/test_prompt_server_integration.py +++ b/tests/integration/prompt-server/test_prompt_server_integration.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING import pytest @@ -109,7 +109,7 @@ async def test_agent_interface_returns_prompts_list(fast_agent): @fast.agent(name="test", servers=["prompts"]) async def agent_function(): async with fast.run() as agent: - prompts: Dict[str, List[Prompt]] = await agent.test.list_prompts() + prompts: dict[str, list[Prompt]] = await agent.test.list_prompts() assert 5 == len(prompts["prompts"]) await agent_function() diff --git a/tests/integration/prompt-state/test_load_prompt_templates.py b/tests/integration/prompt-state/test_load_prompt_templates.py index a1693def4..dbeccfc7f 100644 --- a/tests/integration/prompt-state/test_load_prompt_templates.py +++ b/tests/integration/prompt-state/test_load_prompt_templates.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING import pytest from mcp.types import ImageContent @@ -26,7 +26,7 @@ async def test_load_simple_conversation_from_file(fast_agent): @fast.agent() async def agent_function(): async with fast.run() as agent: - loaded: List[PromptMessageExtended] = load_prompt(Path("conv1_simple.md")) + loaded: list[PromptMessageExtended] = load_prompt(Path("conv1_simple.md")) assert 4 == len(loaded) assert "user" == loaded[0].role assert "assistant" == loaded[1].role diff --git a/tests/integration/tool_loop/test_tool_loop.py b/tests/integration/tool_loop/test_tool_loop.py index cd5289c96..921976e3c 100644 --- a/tests/integration/tool_loop/test_tool_loop.py +++ b/tests/integration/tool_loop/test_tool_loop.py @@ -1,4 +1,3 @@ -from typing import List import pytest from mcp import CallToolRequest, Tool @@ -17,7 +16,7 @@ class ToolGeneratingLlm(PassthroughLLM): async def _apply_prompt_provider_specific( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, @@ -99,7 +98,7 @@ def __init__(self, **kwargs): async def _apply_prompt_provider_specific( self, - multipart_messages: List[PromptMessageExtended], + multipart_messages: list[PromptMessageExtended], request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, diff --git a/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py b/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py index 54ace2ed9..dfae7ceb8 100644 --- a/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py +++ b/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py @@ -1,4 
+1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any import pytest @@ -16,7 +16,7 @@ def _make_agent_config() -> AgentConfig: @pytest.mark.asyncio async def test_local_tools_listed_and_callable() -> None: - calls: List[Dict[str, Any]] = [] + calls: list[dict[str, Any]] = [] def sample_tool(video_id: str) -> str: calls.append({"video_id": video_id}) diff --git a/tests/unit/fast_agent/llm/providers/test_llm_azure.py b/tests/unit/fast_agent/llm/providers/test_llm_azure.py index 153d303aa..53de726b1 100644 --- a/tests/unit/fast_agent/llm/providers/test_llm_azure.py +++ b/tests/unit/fast_agent/llm/providers/test_llm_azure.py @@ -1,5 +1,4 @@ import types -from typing import Optional import pytest @@ -12,11 +11,11 @@ class DummyLogger: class DummyAzureConfig: def __init__(self): - self.api_key: Optional[str] = "test-key" - self.resource_name: Optional[str] = "test-resource" - self.azure_deployment: Optional[str] = "test-deployment" - self.api_version: Optional[str] = "2023-05-15" - self.base_url: Optional[str] = None + self.api_key: str | None = "test-key" + self.resource_name: str | None = "test-resource" + self.azure_deployment: str | None = "test-deployment" + self.api_version: str | None = "2023-05-15" + self.base_url: str | None = None self.use_default_azure_credential: bool = False def get(self, key, default=None): diff --git a/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py b/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py index e57692249..a7ac222e1 100644 --- a/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py +++ b/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -70,7 +69,7 @@ def test_base_url_uses_default_when_config_missing(mock_agent): @patch("fast_agent.llm.provider.openai.llm_openai.OpenAILLM._prepare_api_request") def test_prepare_api_request_with_template_vars(mock_super_prepare, t0_llm): """Tests injection of template_vars into a new system message.""" - messages: List[ChatCompletionMessageParam] = [] + messages: list[ChatCompletionMessageParam] = [] # The super call's return value has its own 'messages' list. We ignore it. 
mock_super_prepare.return_value = {"model": "test_chat", "messages": []} request_params = RequestParams(template_vars={"var1": "value1"}) @@ -91,7 +90,7 @@ def test_prepare_api_request_merges_metadata(mock_super_prepare, t0_llm): initial_system_message = ChatCompletionSystemMessageParam( role="system", content=[{"var1": "original"}] ) - messages: List[ChatCompletionMessageParam] = [initial_system_message] + messages: list[ChatCompletionMessageParam] = [initial_system_message] mock_super_prepare.return_value = {"model": "test_chat", "messages": messages} request_params = RequestParams(metadata={"tensorzero_arguments": {"var2": "metadata_val"}}) @@ -120,7 +119,7 @@ def test_prepare_api_request_all_features(mock_super_prepare, t0_llm): initial_system_message = ChatCompletionSystemMessageParam( role="system", content="Original prompt" ) - messages: List[ChatCompletionMessageParam] = [initial_system_message] + messages: list[ChatCompletionMessageParam] = [initial_system_message] mock_super_prepare.return_value = { "model": "test_chat", "messages": messages, diff --git a/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py b/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py index 7636a5b6b..a9fcbad85 100644 --- a/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py +++ b/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py @@ -1,6 +1,6 @@ import base64 import unittest -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from mcp.types import ( CallToolResult, @@ -32,7 +32,7 @@ def test_tool_result_conversion(self): # tool_call_id = "call_abc123" # Convert directly to OpenAI tool message - converted: List[Content] = self.converter.convert_function_results_to_google( + converted: list[Content] = self.converter.convert_function_results_to_google( [("test", tool_result)] ) assert 1 == len(converted) @@ -62,7 +62,7 @@ def test_multiple_tool_results_with_mixed_content(self): results = [(tool_call_id1, text_result), (tool_call_id2, image_result)] # Convert to OpenAI tool messages - converted: List[Content] = self.converter.convert_function_results_to_google(results) + converted: list[Content] = self.converter.convert_function_results_to_google(results) # Assertions assert 2 == len(converted) diff --git a/tests/unit/fast_agent/llm/test_cache_control_application.py b/tests/unit/fast_agent/llm/test_cache_control_application.py index be82832df..4aff09655 100644 --- a/tests/unit/fast_agent/llm/test_cache_control_application.py +++ b/tests/unit/fast_agent/llm/test_cache_control_application.py @@ -1,5 +1,5 @@ import unittest -from typing import Any, Dict +from typing import Any from mcp.types import TextContent @@ -7,7 +7,7 @@ from fast_agent.mcp.prompt_message_extended import PromptMessageExtended -def apply_cache_control_to_message(message: Dict[str, Any], position: int) -> bool: +def apply_cache_control_to_message(message: dict[str, Any], position: int) -> bool: """ Apply cache control to a message at the specified position. 
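Editor's note, not part of the patch: the hunk above cuts off at the docstring of apply_cache_control_to_message, so its body is not shown here. Purely to illustrate the dict[str, Any] message shape these cache tests work with, the sketch below is a hypothetical stand-in (deliberately given a different name) that tags one content block with an Anthropic-style cache_control marker; the role/content dict layout follows the fixtures used elsewhere in these cache tests, and everything else is assumption, not the repository's implementation.

# Standalone sketch -- assumptions only, not the helper from the patch.
from typing import Any


def mark_cache_control(message: dict[str, Any], position: int) -> bool:
    """Attach an ephemeral cache_control entry to the content block at `position`.

    Returns True when a block was marked, False when the message has no
    content list or the position is out of range.
    """
    content = message.get("content")
    if not isinstance(content, list) or not (0 <= position < len(content)):
        return False
    block = content[position]
    if isinstance(block, dict):
        block["cache_control"] = {"type": "ephemeral"}
        return True
    return False


if __name__ == "__main__":
    msg: dict[str, Any] = {"role": "user", "content": [{"type": "text", "text": "hi"}]}
    assert mark_cache_control(msg, 0) is True
    assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}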
diff --git a/tests/unit/fast_agent/llm/test_cache_walking_real_messages.py b/tests/unit/fast_agent/llm/test_cache_walking_real_messages.py index 6d9ccf246..cbbe2a206 100644 --- a/tests/unit/fast_agent/llm/test_cache_walking_real_messages.py +++ b/tests/unit/fast_agent/llm/test_cache_walking_real_messages.py @@ -1,15 +1,15 @@ import unittest -from typing import Any, Dict, List +from typing import Any from fast_agent.llm.memory import SimpleMemory -def create_message(role: str, text: str, turn: int = 0) -> Dict[str, Any]: +def create_message(role: str, text: str, turn: int = 0) -> dict[str, Any]: """Create a realistic message dict.""" return {"role": role, "content": [{"type": "text", "text": f"{text} (turn {turn})"}]} -def create_tool_response_message(tool_result: str, turn: int = 0) -> Dict[str, Any]: +def create_tool_response_message(tool_result: str, turn: int = 0) -> dict[str, Any]: """Create a tool response message.""" return { "role": "user", @@ -17,7 +17,7 @@ def create_tool_response_message(tool_result: str, turn: int = 0) -> Dict[str, A } -def has_cache_control(message: Dict[str, Any]) -> bool: +def has_cache_control(message: dict[str, Any]) -> bool: """Check if a message has cache control.""" if not isinstance(message, dict) or "content" not in message: return False @@ -30,7 +30,7 @@ def has_cache_control(message: Dict[str, Any]) -> bool: return False -def count_cache_blocks(messages: List[Dict[str, Any]]) -> int: +def count_cache_blocks(messages: list[dict[str, Any]]) -> int: """Count total cache blocks in message array.""" return sum(1 for msg in messages if has_cache_control(msg)) diff --git a/tests/unit/fast_agent/llm/test_prepare_arguments.py b/tests/unit/fast_agent/llm/test_prepare_arguments.py index 287bacc6c..8945abe1c 100644 --- a/tests/unit/fast_agent/llm/test_prepare_arguments.py +++ b/tests/unit/fast_agent/llm/test_prepare_arguments.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from fast_agent.llm.fastagent_llm import FastAgentLLM from fast_agent.llm.provider.anthropic.llm_anthropic import AnthropicLLM @@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs): async def _apply_prompt_provider_specific( self, - multipart_messages: List["PromptMessageExtended"], + multipart_messages: list["PromptMessageExtended"], request_params: RequestParams | None = None, tools=None, is_template: bool = False, @@ -25,8 +25,8 @@ async def _apply_prompt_provider_specific( return multipart_messages[-1] def _convert_extended_messages_to_provider( - self, messages: List[PromptMessageExtended] - ) -> List[Any]: + self, messages: list[PromptMessageExtended] + ) -> list[Any]: """Convert messages to provider format - stub returns empty list""" return [] diff --git a/tests/unit/fast_agent/llm/test_structured.py b/tests/unit/fast_agent/llm/test_structured.py index d8253adc9..44d8c5a0a 100644 --- a/tests/unit/fast_agent/llm/test_structured.py +++ b/tests/unit/fast_agent/llm/test_structured.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import Literal import pytest from pydantic import BaseModel @@ -15,7 +15,7 @@ class StructuredResponseCategory(BaseModel): class StructuredResponse(BaseModel): - categories: List[StructuredResponseCategory] + categories: list[StructuredResponseCategory] @pytest.mark.asyncio diff --git a/tests/unit/fast_agent/mcp/test_ui_mixin.py b/tests/unit/fast_agent/mcp/test_ui_mixin.py index a6da081e9..7345aad0e 100644 --- a/tests/unit/fast_agent/mcp/test_ui_mixin.py +++ b/tests/unit/fast_agent/mcp/test_ui_mixin.py @@ -1,6 +1,5 @@ 
"""Tests for the MCP UI Mixin.""" -from typing import List import pytest from mcp.types import CallToolResult, EmbeddedResource, TextContent, TextResourceContents @@ -42,8 +41,8 @@ async def run_tools(self, request): async def show_assistant_message( self, message, - bottom_items: List[str] | None = None, - highlight_items: str | List[str] | None = None, + bottom_items: list[str] | None = None, + highlight_items: str | list[str] | None = None, max_item_length: int | None = None, name: str | None = None, model: str | None = None, diff --git a/tests/unit/fast_agent/tools/test_shell_runtime.py b/tests/unit/fast_agent/tools/test_shell_runtime.py index 35df98065..35e0224d3 100644 --- a/tests/unit/fast_agent/tools/test_shell_runtime.py +++ b/tests/unit/fast_agent/tools/test_shell_runtime.py @@ -5,7 +5,7 @@ import subprocess from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Any import pytest @@ -57,15 +57,15 @@ def _no_progress(): def _setup_runtime( - monkeypatch: pytest.MonkeyPatch, runtime_info: Dict[str, str] -) -> Tuple[ShellRuntime, DummyProcess, Dict[str, Any]]: + monkeypatch: pytest.MonkeyPatch, runtime_info: dict[str, str] +) -> tuple[ShellRuntime, DummyProcess, dict[str, Any]]: logger = logging.getLogger("shell-runtime-test") runtime = ShellRuntime(activation_reason="test", logger=logger) runtime.runtime_info = lambda: runtime_info # type: ignore[assignment] runtime.working_directory = lambda: Path(".") # type: ignore[assignment] dummy_process = DummyProcess() - captured: Dict[str, Any] = {} + captured: dict[str, Any] = {} async def fake_exec(*args, **kwargs): captured["exec_args"] = args