diff --git a/agent.py b/agent.py deleted file mode 100644 index 2f75147e1..000000000 --- a/agent.py +++ /dev/null @@ -1,27 +0,0 @@ -import asyncio - -from fast_agent import FastAgent - -# Create the application -fast = FastAgent("fast-agent example") - - -default_instruction = """You are a helpful AI Agent. - -{{serverInstructions}} - -The current date is {{currentDate}}.""" - - -# Define the agent -@fast.agent(instruction=default_instruction) -async def main(): - # use the --model command line switch or agent arguments to change model - async with fast.run() as agent: - await agent.send("tabulate the top 50 airports and include a small fact about the city it is closest to") - await agent.interactive() - await agent.send("write 10 demonstration typescript programs of around 50 lines each demonstrating different transport features") - await agent.interactive() - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/fast_agent/agents/llm_decorator.py b/src/fast_agent/agents/llm_decorator.py index 768e08c32..6544eb4a1 100644 --- a/src/fast_agent/agents/llm_decorator.py +++ b/src/fast_agent/agents/llm_decorator.py @@ -42,7 +42,11 @@ from pydantic import BaseModel from fast_agent.agents.agent_types import AgentConfig, AgentType -from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL, FAST_AGENT_REMOVED_METADATA_CHANNEL +from fast_agent.constants import ( + CONTROL_MESSAGE_SAVE_HISTORY, + FAST_AGENT_ERROR_CHANNEL, + FAST_AGENT_REMOVED_METADATA_CHANNEL, +) from fast_agent.context import Context from fast_agent.core.logging.logger import get_logger from fast_agent.interfaces import ( @@ -127,6 +131,17 @@ class RemovedContentSummary: message: str +@dataclass +class _CallContext: + """Internal helper for assembling an LLM call.""" + + full_history: List[PromptMessageExtended] + call_params: RequestParams | None + persist_history: bool + sanitized_messages: List[PromptMessageExtended] + summary: RemovedContentSummary | None + + class 
LlmDecorator(StreamingAgentMixin, AgentProtocol): """ A pure delegation wrapper around LlmAgent instances. @@ -150,6 +165,9 @@ def __init__( self._tracer = trace.get_tracer(__name__) self.instruction = self.config.instruction + # Agent-owned conversation state (PromptMessageExtended only) + self._message_history: List[PromptMessageExtended] = [] + # Store the default request params from config self._default_request_params = self.config.default_request_params @@ -338,7 +356,16 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_nam Returns: String representation of the assistant's response if generated """ + from fast_agent.types import PromptMessageExtended + assert self._llm + + multipart_messages = PromptMessageExtended.parse_get_prompt_result(prompt_result) + for msg in multipart_messages: + msg.is_template = True + + self._message_history = [msg.model_copy(deep=True) for msg in multipart_messages] + return await self._llm.apply_prompt_template(prompt_result, prompt_name) async def apply_prompt( @@ -375,6 +402,11 @@ def clear(self, *, clear_prompts: bool = False) -> None: if not self._llm: return self._llm.clear(clear_prompts=clear_prompts) + if clear_prompts: + self._message_history = [] + else: + template_prefix = self._template_prefix_messages() + self._message_history = [msg.model_copy(deep=True) for msg in template_prefix] async def structured( self, @@ -445,9 +477,16 @@ async def _generate_with_summary( tools: List[Tool] | None = None, ) -> Tuple[PromptMessageExtended, RemovedContentSummary | None]: assert self._llm, "LLM is not attached" - sanitized_messages, summary = self._sanitize_messages_for_llm(messages) - response = await self._llm.generate(sanitized_messages, request_params, tools) - return response, summary + call_ctx = self._prepare_llm_call(messages, request_params) + + response = await self._llm.generate( + call_ctx.full_history, call_ctx.call_params, tools + ) + + if call_ctx.persist_history: + 
self._persist_history(call_ctx.sanitized_messages, response) + + return response, call_ctx.summary async def _structured_with_summary( self, @@ -456,9 +495,68 @@ async def _structured_with_summary( request_params: RequestParams | None = None, ) -> Tuple[Tuple[ModelT | None, PromptMessageExtended], RemovedContentSummary | None]: assert self._llm, "LLM is not attached" + call_ctx = self._prepare_llm_call(messages, request_params) + + structured_result = await self._llm.structured( + call_ctx.full_history, model, call_ctx.call_params + ) + + if call_ctx.persist_history: + try: + _, assistant_message = structured_result + self._persist_history(call_ctx.sanitized_messages, assistant_message) + except Exception: + pass + return structured_result, call_ctx.summary + + def _prepare_llm_call( + self, messages: List[PromptMessageExtended], request_params: RequestParams | None = None + ) -> _CallContext: + """Normalize template/history handling for both generate and structured.""" sanitized_messages, summary = self._sanitize_messages_for_llm(messages) - structured_result = await self._llm.structured(sanitized_messages, model, request_params) - return structured_result, summary + final_request_params = self._llm.get_request_params(request_params) + + use_history = final_request_params.use_history if final_request_params else True + call_params = final_request_params.model_copy() if final_request_params else None + if call_params and not call_params.use_history: + call_params.use_history = True + + base_history = self._message_history if use_history else self._template_prefix_messages() + full_history = [msg.model_copy(deep=True) for msg in base_history] + full_history.extend(sanitized_messages) + + return _CallContext( + full_history=full_history, + call_params=call_params, + persist_history=use_history, + sanitized_messages=sanitized_messages, + summary=summary, + ) + + def _persist_history( + self, + sanitized_messages: List[PromptMessageExtended], + assistant_message: 
PromptMessageExtended, + ) -> None: + """Persist the last turn unless explicitly disabled by control text.""" + if not sanitized_messages: + return + if sanitized_messages[-1].first_text().startswith(CONTROL_MESSAGE_SAVE_HISTORY): + return + + history_messages = [self._strip_removed_metadata(msg) for msg in sanitized_messages] + self._message_history.extend(history_messages) + self._message_history.append(assistant_message) + + @staticmethod + def _strip_removed_metadata(message: PromptMessageExtended) -> PromptMessageExtended: + """Remove per-turn removed-content metadata before persisting to history.""" + msg_copy = message.model_copy(deep=True) + if msg_copy.channels and FAST_AGENT_REMOVED_METADATA_CHANNEL in msg_copy.channels: + channels = dict(msg_copy.channels) + channels.pop(FAST_AGENT_REMOVED_METADATA_CHANNEL, None) + msg_copy.channels = channels if channels else None + return msg_copy def _sanitize_messages_for_llm( self, messages: List[PromptMessageExtended] @@ -761,9 +859,27 @@ def message_history(self) -> List[PromptMessageExtended]: Returns: List of PromptMessageExtended objects representing the conversation history """ - if self._llm: - return self._llm.message_history - return [] + return self._message_history + + @property + def template_messages(self) -> List[PromptMessageExtended]: + """ + Return the template prefix of the message history. + + Templates are identified via the is_template flag and are expected to + appear as a contiguous prefix of the history. 
+ """ + return [msg.model_copy(deep=True) for msg in self._template_prefix_messages()] + + def _template_prefix_messages(self) -> List[PromptMessageExtended]: + """Return the leading messages marked as templates (non-copy).""" + prefix: List[PromptMessageExtended] = [] + for msg in self._message_history: + if msg.is_template: + prefix.append(msg) + else: + break + return prefix def pop_last_message(self) -> PromptMessageExtended | None: """Remove and return the most recent message from the conversation history.""" diff --git a/src/fast_agent/agents/mcp_agent.py b/src/fast_agent/agents/mcp_agent.py index aeb345b56..6d687bf96 100644 --- a/src/fast_agent/agents/mcp_agent.py +++ b/src/fast_agent/agents/mcp_agent.py @@ -1320,9 +1320,8 @@ def message_history(self) -> List[PromptMessageExtended]: Returns: List of PromptMessageExtended objects representing the conversation history """ - if self._llm: - return self._llm.message_history - return [] + # Conversation history is maintained at the agent layer; LLM history is diagnostic only. 
+ return super().message_history @property def usage_accumulator(self) -> Optional["UsageAccumulator"]: diff --git a/src/fast_agent/constants.py b/src/fast_agent/constants.py index 9b4b028af..812d1bde0 100644 --- a/src/fast_agent/constants.py +++ b/src/fast_agent/constants.py @@ -31,3 +31,5 @@ {{env}} The current date is {{currentDate}}.""" + +CONTROL_MESSAGE_SAVE_HISTORY = "***SAVE_HISTORY" diff --git a/src/fast_agent/llm/fastagent_llm.py b/src/fast_agent/llm/fastagent_llm.py index 0de9b5a06..74599eb53 100644 --- a/src/fast_agent/llm/fastagent_llm.py +++ b/src/fast_agent/llm/fastagent_llm.py @@ -26,7 +26,11 @@ from openai.lib._parsing import type_to_response_format_param as _type_to_response_format from pydantic_core import from_json -from fast_agent.constants import DEFAULT_MAX_ITERATIONS, FAST_AGENT_TIMING +from fast_agent.constants import ( + CONTROL_MESSAGE_SAVE_HISTORY, + DEFAULT_MAX_ITERATIONS, + FAST_AGENT_TIMING, +) from fast_agent.context_dependent import ContextDependent from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt @@ -133,9 +137,6 @@ def __init__( # memory contains provider specific API types. 
self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]() - self._message_history: List[PromptMessageExtended] = [] - self._template_messages: List[PromptMessageExtended] = [] - # Initialize the display component from fast_agent.ui.console_display import ConsoleDisplay @@ -203,7 +204,7 @@ async def generate( """ # TODO -- create a "fast-agent" control role rather than magic strings - if messages[-1].first_text().startswith("***SAVE_HISTORY"): + if messages[-1].first_text().startswith(CONTROL_MESSAGE_SAVE_HISTORY): parts: list[str] = messages[-1].first_text().split(" ", 1) if len(parts) > 1: filename: str = parts[1].strip() @@ -212,20 +213,21 @@ async def generate( timestamp = datetime.now().strftime("%y_%m_%d_%H_%M") filename = f"{timestamp}-conversation.json" - await self._save_history(filename) + await self._save_history(filename, messages) return Prompt.assistant(f"History saved to {filename}") - self._precall(messages) - # Store MCP metadata in context variable final_request_params = self.get_request_params(request_params) if final_request_params.mcp_metadata: _mcp_metadata_var.set(final_request_params.mcp_metadata) + # The caller supplies the full conversation to send + full_history = messages + # Track timing for this generation start_time = time.perf_counter() assistant_response: PromptMessageExtended = await self._apply_prompt_provider_specific( - messages, request_params, tools + full_history, request_params, tools ) end_time = time.perf_counter() duration_ms = round((end_time - start_time) * 1000, 2) @@ -244,12 +246,6 @@ async def generate( self.usage_accumulator.count_tools(len(assistant_response.tool_calls or {})) - # add generic error and termination reason handling/rollback - # Only append if it's not already the last message in history - # (this can happen when loading a saved history that ends with an assistant message) - if not self._message_history or self._message_history[-1] is not assistant_response: - 
self._message_history.append(assistant_response) - return assistant_response @abstractmethod @@ -295,8 +291,6 @@ async def structured( Tuple of (parsed model instance or None, assistant response message) """ - self._precall(messages) - # Store MCP metadata in context variable final_request_params = self.get_request_params(request_params) @@ -304,10 +298,12 @@ async def structured( if final_request_params.mcp_metadata: _mcp_metadata_var.set(final_request_params.mcp_metadata) + full_history = messages + # Track timing for this structured generation start_time = time.perf_counter() result, assistant_response = await self._apply_prompt_provider_specific_structured( - messages, model, request_params + full_history, model, request_params ) end_time = time.perf_counter() duration_ms = round((end_time - start_time) * 1000, 2) @@ -324,7 +320,6 @@ async def structured( channels[FAST_AGENT_TIMING] = [TextContent(type="text", text=json.dumps(timing_data))] assistant_response.channels = channels - self._message_history.append(assistant_response) return result, assistant_response @staticmethod @@ -405,14 +400,17 @@ def _prepare_structured_text(self, text: str) -> str: """Hook for subclasses to adjust structured output text before parsing.""" return text + def record_templates(self, templates: List[PromptMessageExtended]) -> None: + """Hook for providers that need template visibility (e.g., caching).""" + return + def _precall(self, multipart_messages: List[PromptMessageExtended]) -> None: """Pre-call hook to modify the message before sending it to the provider.""" - # Ensure all messages are PromptMessageExtended before extending history - self._message_history.extend(multipart_messages) + # No-op placeholder; history is managed by the agent def chat_turn(self) -> int: """Return the current chat turn number""" - return 1 + sum(1 for message in self._message_history if message.role == "assistant") + return 1 + len(self._usage_accumulator.turns) def prepare_provider_arguments( 
self, @@ -630,6 +628,37 @@ def _convert_prompt_messages(self, prompt_messages: List[PromptMessage]) -> List """ raise NotImplementedError("Must be implemented by subclass") + def _convert_to_provider_format( + self, messages: List[PromptMessageExtended] + ) -> List[MessageParamT]: + """ + Convert provided messages to provider-specific format. + Called fresh on EVERY API call - no caching. + + Args: + messages: List of PromptMessageExtended + + Returns: + List of provider-specific message objects + """ + return self._convert_extended_messages_to_provider(messages) + + @abstractmethod + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[MessageParamT]: + """ + Convert PromptMessageExtended list to provider-specific format. + Must be implemented by each provider. + + Args: + messages: List of PromptMessageExtended objects + + Returns: + List of provider-specific message parameter objects + """ + raise NotImplementedError("Must be implemented by subclass") + async def show_prompt_loaded( self, prompt_name: str, @@ -685,20 +714,14 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_nam arguments=arguments, ) - # Convert to PromptMessageExtended objects + # Convert to PromptMessageExtended objects and delegate multipart_messages = PromptMessageExtended.parse_get_prompt_result(prompt_result) - # Store a local copy of template messages so we can retain them across clears - self._template_messages = [msg.model_copy(deep=True) for msg in multipart_messages] - - # Delegate to the provider-specific implementation result = await self._apply_prompt_provider_specific( multipart_messages, None, is_template=True ) - # Ensure message history always includes the stored template when applied - self._message_history = [msg.model_copy(deep=True) for msg in self._template_messages] return result.first_text() - async def _save_history(self, filename: str) -> None: + async def _save_history(self, filename: str, 
messages: List[PromptMessageExtended]) -> None: """ Save the Message History to a file in a format determined by the file extension. @@ -707,8 +730,15 @@ async def _save_history(self, filename: str) -> None: """ from fast_agent.mcp.prompt_serialization import save_messages + # Drop control messages like ***SAVE_HISTORY before persisting + filtered = [ + msg.model_copy(deep=True) + for msg in messages + if not msg.first_text().startswith(CONTROL_MESSAGE_SAVE_HISTORY) + ] + # Save messages using the unified save function that auto-detects format - save_messages(self._message_history, filename) + save_messages(filtered, filename) @property def message_history(self) -> List[PromptMessageExtended]: @@ -721,32 +751,16 @@ def message_history(self) -> List[PromptMessageExtended]: Returns: List of PromptMessageExtended objects representing the conversation history """ - return self._message_history + return [] def pop_last_message(self) -> PromptMessageExtended | None: """Remove and return the most recent message from the conversation history.""" - if not self._message_history: - return None - - removed = self._message_history.pop() - try: - self.history.pop() - except Exception: - # If provider-specific memory isn't available, ignore to avoid crashing UX - pass - return removed + return None def clear(self, *, clear_prompts: bool = False) -> None: """Reset stored message history while optionally retaining prompt templates.""" self.history.clear(clear_prompts=clear_prompts) - if clear_prompts: - self._template_messages = [] - self._message_history = [] - return - - # Restore message history to template messages only; new turns will append as normal - self._message_history = [msg.model_copy(deep=True) for msg in self._template_messages] def _api_key(self): if self._init_api_key: diff --git a/src/fast_agent/llm/internal/passthrough.py b/src/fast_agent/llm/internal/passthrough.py index 52d59be50..443cfc51f 100644 --- a/src/fast_agent/llm/internal/passthrough.py +++ 
b/src/fast_agent/llm/internal/passthrough.py @@ -81,6 +81,10 @@ async def _apply_prompt_provider_specific( self.history.extend(multipart_messages, is_prompt=is_template) last_message = multipart_messages[-1] + # If the caller already provided an assistant reply (e.g., history replay), return it as-is. + if last_message.role == "assistant": + return last_message + tool_calls: Dict[str, CallToolRequest] = {} stop_reason: LlmStopReason = LlmStopReason.END_TURN if self.is_tool_call(last_message): @@ -112,9 +116,14 @@ async def _apply_prompt_provider_specific( self._fixed_response, tool_calls=tool_calls, stop_reason=stop_reason ) else: - concatenated_content = "\n".join( - [message.all_text() for message in multipart_messages if "user" == message.role] - ) + # Walk backwards through messages concatenating while role is "user" + user_messages = [] + for message in reversed(multipart_messages): + if message.role != "user": + break + user_messages.append(message.all_text()) + concatenated_content = "\n".join(reversed(user_messages)) + result = Prompt.assistant( concatenated_content, tool_calls=tool_calls, @@ -133,5 +142,20 @@ async def _apply_prompt_provider_specific( return result + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[Any]: + """ + Convert PromptMessageExtended list to provider format. + For PassthroughLLM, we don't actually make API calls, so this just returns empty list. 
+ + Args: + messages: List of PromptMessageExtended objects + + Returns: + Empty list (passthrough doesn't use provider-specific messages) + """ + return [] + def is_tool_call(self, message: PromptMessageExtended) -> bool: return message.first_text().startswith(CALL_TOOL_INDICATOR) diff --git a/src/fast_agent/llm/memory.py b/src/fast_agent/llm/memory.py index 1e7f190be..88367ea71 100644 --- a/src/fast_agent/llm/memory.py +++ b/src/fast_agent/llm/memory.py @@ -7,6 +7,13 @@ class Memory(Protocol, Generic[MessageParamT]): """ Simple memory management for storing past interactions in-memory. + + IMPORTANT: As of the conversation history architecture refactor, + provider history is DIAGNOSTIC ONLY. Messages are generated fresh + from _message_history on each API call via _convert_to_provider_format(). + + The get() method should NOT be called by provider code for API calls. + It may still be used for debugging/inspection purposes. """ # TODO: saqadri - add checkpointing and other advanced memory capabilities @@ -86,13 +93,23 @@ def get(self, include_completion_history: bool = True) -> List[MessageParamT]: """ Get all messages in memory. + DEPRECATED: Provider history is now diagnostic only. This method returns + a diagnostic snapshot and should NOT be used for API calls. Messages for + API calls are generated fresh from _message_history via + _convert_to_provider_format(). + Args: include_history: If True, include regular history messages If False, only return prompt messages Returns: Combined list of prompt messages and optionally history messages + (for diagnostic/inspection purposes only) """ + # Note: We don't emit a warning here because this method is still + # legitimately used for diagnostic purposes and by some internal code. + # The important change is that provider completion methods no longer + # call this for API message construction. 
if include_completion_history: return self.prompt_messages + self.history else: diff --git a/src/fast_agent/llm/model_factory.py b/src/fast_agent/llm/model_factory.py index b89197d2a..9058e6037 100644 --- a/src/fast_agent/llm/model_factory.py +++ b/src/fast_agent/llm/model_factory.py @@ -136,7 +136,7 @@ class ModelFactory: "gpt-oss": "hf.openai/gpt-oss-120b", "gpt-oss-20b": "hf.openai/gpt-oss-20b", "glm": "hf.zai-org/GLM-4.6", - "qwen3": "hf.Qwen/Qwen3-Next-80B-A3B-Instruct", + "qwen3": "hf.Qwen/Qwen3-Next-80B-A3B-Instruct:together", "deepseek31": "hf.deepseek-ai/DeepSeek-V3.1", "kimithink": "hf.moonshotai/Kimi-K2-Thinking:together", } diff --git a/src/fast_agent/llm/provider/anthropic/cache_planner.py b/src/fast_agent/llm/provider/anthropic/cache_planner.py new file mode 100644 index 000000000..fcc9135d2 --- /dev/null +++ b/src/fast_agent/llm/provider/anthropic/cache_planner.py @@ -0,0 +1,57 @@ +from typing import List + +from fast_agent.mcp.prompt_message_extended import PromptMessageExtended + + +class AnthropicCachePlanner: + """Calculate where to apply Anthropic cache_control blocks.""" + + def __init__( + self, + walk_distance: int = 6, + max_conversation_blocks: int = 2, + max_total_blocks: int = 4, + ) -> None: + self.walk_distance = walk_distance + self.max_conversation_blocks = max_conversation_blocks + self.max_total_blocks = max_total_blocks + + def _template_prefix_count(self, messages: List[PromptMessageExtended]) -> int: + return sum(msg.is_template for msg in messages) + + def plan_indices( + self, + messages: List[PromptMessageExtended], + cache_mode: str, + system_cache_blocks: int = 0, + ) -> List[int]: + """Return message indices that should receive cache_control.""" + + if cache_mode == "off" or not messages: + return [] + + budget = max(0, self.max_total_blocks - system_cache_blocks) + if budget == 0: + return [] + + template_prefix = self._template_prefix_count(messages) + template_indices: List[int] = [] + + if cache_mode in ("prompt", 
"auto") and template_prefix: + template_indices = list(range(min(template_prefix, budget))) + budget -= len(template_indices) + + conversation_indices: List[int] = [] + if cache_mode == "auto" and budget > 0: + conv_count = max(0, len(messages) - template_prefix) + if conv_count >= self.walk_distance: + positions = [ + template_prefix + i + for i in range(self.walk_distance - 1, conv_count, self.walk_distance) + ] + + # Respect Anthropic limits and remaining budget + positions = positions[-self.max_conversation_blocks :] + conversation_indices = positions[:budget] + + return template_indices + conversation_indices diff --git a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py index 25bb6e188..ea0008739 100644 --- a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py @@ -32,6 +32,7 @@ FastAgentLLM, RequestParams, ) +from fast_agent.llm.provider.anthropic.cache_planner import AnthropicCachePlanner from fast_agent.llm.provider.anthropic.multipart_converter_anthropic import ( AnthropicConverter, ) @@ -51,6 +52,8 @@ class AnthropicLLM(FastAgentLLM[MessageParam, Message]): + CONVERSATION_CACHE_WALK_DISTANCE = 6 + MAX_CONVERSATION_CACHE_BLOCKS = 2 # Anthropic-specific parameter exclusions ANTHROPIC_EXCLUDE_FIELDS = { FastAgentLLM.PARAM_MESSAGES, @@ -115,7 +118,7 @@ async def _prepare_tools( for tool in tools or [] ] - def _apply_system_cache(self, base_args: dict, cache_mode: str) -> None: + def _apply_system_cache(self, base_args: dict, cache_mode: str) -> int: """Apply cache control to system prompt if cache mode allows it.""" system_content: SystemParam | None = base_args.get("system") @@ -130,41 +133,31 @@ def _apply_system_cache(self, base_args: dict, cache_mode: str) -> None: logger.debug( "Applied cache_control to system prompt (caches tools+system in one block)" ) + return 1 # If it's already a list (shouldn't happen in current flow but 
type-safe) elif isinstance(system_content, list): logger.debug("System prompt already in list format") else: logger.debug(f"Unexpected system prompt type: {type(system_content)}") - async def _apply_conversation_cache(self, messages: List[MessageParam], cache_mode: str) -> int: - """Apply conversation caching if in auto mode. Returns number of cache blocks applied.""" - applied_count = 0 - if cache_mode == "auto" and self.history.should_apply_conversation_cache(): - cache_updates = self.history.get_conversation_cache_updates() + return 0 - # Remove cache control from old positions - if cache_updates["remove"]: - self.history.remove_cache_control_from_messages(messages, cache_updates["remove"]) - logger.debug( - f"Removed conversation cache_control from positions {cache_updates['remove']}" - ) + @staticmethod + def _apply_cache_control_to_message(message: MessageParam) -> bool: + """Apply cache control to the last content block of a message.""" + if not isinstance(message, dict) or "content" not in message: + return False - # Add cache control to new positions - if cache_updates["add"]: - applied_count = self.history.add_cache_control_to_messages( - messages, cache_updates["add"] - ) - if applied_count > 0: - self.history.apply_conversation_cache_updates(cache_updates) - logger.debug( - f"Applied conversation cache_control to positions {cache_updates['add']} ({applied_count} blocks)" - ) - else: - logger.debug( - f"Failed to apply conversation cache_control to positions {cache_updates['add']}" - ) + content_list = message["content"] + if not isinstance(content_list, list) or not content_list: + return False - return applied_count + for content_block in reversed(content_list): + if isinstance(content_block, dict): + content_block["cache_control"] = {"type": "ephemeral"} + return True + + return False def _is_structured_output_request(self, tool_uses: List[Any]) -> bool: """ @@ -454,6 +447,32 @@ def _stream_failure_response(self, error: APIError, model_name: str) -> 
PromptMe stop_reason=LlmStopReason.ERROR, ) + def _build_request_messages( + self, + params: RequestParams, + message_param: MessageParam, + pre_messages: List[MessageParam] | None = None, + history: List[PromptMessageExtended] | None = None, + ) -> List[MessageParam]: + """ + Build the list of Anthropic message parameters for the next request. + + Ensures that the current user message is only included once when history + is enabled, which prevents duplicate tool_result blocks from being sent. + """ + messages: List[MessageParam] = list(pre_messages) if pre_messages else [] + + history_messages: List[MessageParam] = [] + if params.use_history and history: + history_messages = self._convert_to_provider_format(history) + messages.extend(history_messages) + + include_current = not params.use_history or not history_messages + if include_current: + messages.append(message_param) + + return messages + async def _anthropic_completion( self, message_param, @@ -461,6 +480,8 @@ async def _anthropic_completion( structured_model: Type[ModelT] | None = None, tools: List[Tool] | None = None, pre_messages: List[MessageParam] | None = None, + history: List[PromptMessageExtended] | None = None, + current_extended: PromptMessageExtended | None = None, ) -> PromptMessageExtended: """ Process a query using an LLM and available tools. 
@@ -474,18 +495,14 @@ async def _anthropic_completion( try: anthropic = AsyncAnthropic(api_key=api_key, base_url=base_url) - messages: List[MessageParam] = list(pre_messages) if pre_messages else [] params = self.get_request_params(request_params) + messages = self._build_request_messages(params, message_param, pre_messages, history=history) except AuthenticationError as e: raise ProviderKeyError( "Invalid Anthropic API key", "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.", ) from e - # Always include prompt messages, but only include conversation history if enabled - messages.extend(self.history.get(include_completion_history=params.use_history)) - messages.append(message_param) # message_param is the current user turn - # Get cache mode configuration cache_mode = self._get_cache_mode() logger.debug(f"Anthropic cache_mode: {cache_mode}") @@ -521,20 +538,25 @@ async def _anthropic_completion( ) # Apply cache control to system prompt AFTER merging arguments - self._apply_system_cache(arguments, cache_mode) - - # Apply conversation caching - applied_count = await self._apply_conversation_cache(messages, cache_mode) - - # Verify we don't exceed Anthropic's 4 cache block limit - if applied_count > 0: - total_cache_blocks = applied_count - if cache_mode != "off" and arguments["system"]: - total_cache_blocks += 1 # tools+system cache block - if total_cache_blocks > 4: - logger.warning( - f"Total cache blocks ({total_cache_blocks}) exceeds Anthropic limit of 4" - ) + system_cache_applied = self._apply_system_cache(arguments, cache_mode) + + # Apply cache_control markers using planner + planner = AnthropicCachePlanner( + self.CONVERSATION_CACHE_WALK_DISTANCE, self.MAX_CONVERSATION_CACHE_BLOCKS + ) + plan_messages: List[PromptMessageExtended] = [] + include_current = not params.use_history or not history + if params.use_history and history: + plan_messages.extend(history) + if include_current and current_extended: + 
plan_messages.append(current_extended) + + cache_indices = planner.plan_indices( + plan_messages, cache_mode=cache_mode, system_cache_blocks=system_cache_applied + ) + for idx in cache_indices: + if 0 <= idx < len(messages): + self._apply_cache_control_to_message(messages[idx]) logger.debug(f"{arguments}") # Use streaming API with helper @@ -607,13 +629,9 @@ async def _anthropic_completion( else: tool_calls = self._build_tool_calls_dict(tool_uses) - # Only save the new conversation messages to history if use_history is true - # Keep the prompt messages separate - if params.use_history: - # Get current prompt messages - prompt_messages = self.history.get(include_completion_history=False) - new_messages = messages[len(prompt_messages) :] - self.history.set(new_messages) + # Update diagnostic snapshot (never read again) + # This provides a snapshot of what was sent to the provider for debugging + self.history.set(messages) self._log_chat_finished(model=model) @@ -628,50 +646,25 @@ async def _apply_prompt_provider_specific( tools: List[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: - # Effective params for this turn - params = self.get_request_params(request_params) - + """ + Provider-specific prompt application. + Templates are handled by the agent; messages already include them. 
+ """ # Check the last message role last_message = multipart_messages[-1] - # Add all previous messages to history (or all messages if last is from assistant) - messages_to_add = ( - multipart_messages[:-1] if last_message.role == "user" else multipart_messages - ) - converted: List[MessageParam] = [] - - # Get cache mode configuration - cache_mode = self._get_cache_mode() - - for msg in messages_to_add: - anthropic_msg = AnthropicConverter.convert_to_anthropic(msg) - - # Apply caching to template messages if cache_mode is "prompt" or "auto" - if is_template and cache_mode in ["prompt", "auto"] and anthropic_msg.get("content"): - content_list = anthropic_msg["content"] - if isinstance(content_list, list) and content_list: - # Apply cache control to the last content block - last_block = content_list[-1] - if isinstance(last_block, dict): - last_block["cache_control"] = {"type": "ephemeral"} - logger.debug( - f"Applied cache_control to template message with role {anthropic_msg.get('role')}" - ) - - converted.append(anthropic_msg) - - # Persist prior only when history is enabled; otherwise inline for this call - pre_messages: List[MessageParam] | None = None - if params.use_history: - self.history.extend(converted, is_prompt=is_template) - else: - pre_messages = converted - if last_message.role == "user": logger.debug("Last message in prompt is from user, generating assistant response") message_param = AnthropicConverter.convert_to_anthropic(last_message) + # No need to pass pre_messages - conversion happens in _anthropic_completion + # via _convert_to_provider_format() return await self._anthropic_completion( - message_param, request_params, tools=tools, pre_messages=pre_messages + message_param, + request_params, + tools=tools, + pre_messages=None, + history=multipart_messages, + current_extended=last_message, ) else: # For assistant messages: Return the last message content as text @@ -684,30 +677,27 @@ async def _apply_prompt_provider_specific_structured( model: 
Type[ModelT], request_params: RequestParams | None = None, ) -> Tuple[ModelT | None, PromptMessageExtended]: # noqa: F821 + """ + Provider-specific structured output implementation. + Note: Message history is managed by base class and converted via + _convert_to_provider_format() on each call. + """ request_params = self.get_request_params(request_params) # Check the last message role last_message = multipart_messages[-1] - # Add all previous messages to history (or all messages if last is from assistant) - messages_to_add = ( - multipart_messages[:-1] if last_message.role == "user" else multipart_messages - ) - converted = [] - - for msg in messages_to_add: - anthropic_msg = AnthropicConverter.convert_to_anthropic(msg) - converted.append(anthropic_msg) - - self.history.extend(converted, is_prompt=False) - if last_message.role == "user": logger.debug("Last message in prompt is from user, generating structured response") message_param = AnthropicConverter.convert_to_anthropic(last_message) # Call _anthropic_completion with the structured model result: PromptMessageExtended = await self._anthropic_completion( - message_param, request_params, structured_model=model + message_param, + request_params, + structured_model=model, + history=multipart_messages, + current_extended=last_message, ) for content in result.content: @@ -727,6 +717,21 @@ async def _apply_prompt_provider_specific_structured( logger.debug("Last message in prompt is from assistant, returning it directly") return None, last_message + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[MessageParam]: + """ + Convert PromptMessageExtended list to Anthropic MessageParam format. + This is called fresh on every API call from _convert_to_provider_format(). 
+ + Args: + messages: List of PromptMessageExtended objects + + Returns: + List of Anthropic MessageParam objects + """ + return [AnthropicConverter.convert_to_anthropic(msg) for msg in messages] + @classmethod def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam: """Convert a response object to an input parameter object to allow LLM calls to be chained.""" diff --git a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py index f9dccf4fa..f9b880e81 100644 --- a/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py @@ -1,3 +1,4 @@ +import re from typing import List, Sequence, Union from anthropic.types import ( @@ -81,6 +82,7 @@ def convert_to_anthropic(multipart_msg: PromptMessageExtended) -> MessageParam: # legally include corresponding tool_result blocks. if role == "assistant" and multipart_msg.tool_calls: for tool_use_id, req in multipart_msg.tool_calls.items(): + sanitized_id = AnthropicConverter._sanitize_tool_id(tool_use_id) name = None args = None try: @@ -94,7 +96,7 @@ def convert_to_anthropic(multipart_msg: PromptMessageExtended) -> MessageParam: all_content_blocks.append( ToolUseBlockParam( type="tool_use", - id=tool_use_id, + id=sanitized_id, name=name or "unknown_tool", input=args or {}, ) @@ -404,6 +406,7 @@ def create_tool_results_message( content_blocks = [] for tool_use_id, result in tool_results: + sanitized_id = AnthropicConverter._sanitize_tool_id(tool_use_id) # Process each tool result tool_result_blocks = [] @@ -427,7 +430,7 @@ def create_tool_results_message( content_blocks.append( ToolResultBlockParam( type="tool_result", - tool_use_id=tool_use_id, + tool_use_id=sanitized_id, content=tool_result_blocks, is_error=result.isError, ) @@ -437,7 +440,7 @@ def create_tool_results_message( content_blocks.append( ToolResultBlockParam( 
type="tool_result", - tool_use_id=tool_use_id, + tool_use_id=sanitized_id, content=[TextBlockParam(type="text", text="[No content in tool result]")], is_error=result.isError, ) @@ -446,3 +449,14 @@ def create_tool_results_message( # All content is now included within the tool_result block. return MessageParam(role="user", content=content_blocks) + + @staticmethod + def _sanitize_tool_id(tool_id: str | None) -> str: + """ + Anthropic tool_use ids must match ^[a-zA-Z0-9_-]+$. + Clean any other characters to underscores and provide a stable fallback. + """ + if not tool_id: + return "tool" + cleaned = re.sub(r"[^a-zA-Z0-9_-]", "_", tool_id) + return cleaned or "tool" diff --git a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py index 8b22e477c..cb1795472 100644 --- a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py @@ -19,6 +19,7 @@ from fast_agent.event_progress import ProgressAction from fast_agent.interfaces import ModelT from fast_agent.llm.fastagent_llm import FastAgentLLM +from fast_agent.llm.provider.bedrock.multipart_converter_bedrock import BedrockConverter from fast_agent.llm.provider_types import Provider from fast_agent.llm.usage_tracking import TurnUsage from fast_agent.types import PromptMessageExtended, RequestParams @@ -279,6 +280,25 @@ def _get_bedrock_runtime_client(self): ) from e return self._bedrock_runtime_client + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[BedrockMessageParam]: + """ + Convert PromptMessageExtended list to Bedrock BedrockMessageParam format. + This is called fresh on every API call from _convert_to_provider_format(). 
+ + Args: + messages: List of PromptMessageExtended objects + + Returns: + List of Bedrock BedrockMessageParam objects + """ + converted: List[BedrockMessageParam] = [] + for msg in messages: + bedrock_msg = BedrockConverter.convert_to_bedrock(msg) + converted.append(bedrock_msg) + return converted + def _build_tool_name_mapping( self, tools: "ListToolsResult", name_policy: ToolNamePolicy ) -> Dict[str, str]: @@ -1193,6 +1213,7 @@ async def _bedrock_completion( request_params: RequestParams | None = None, tools: List[Tool] | None = None, pre_messages: List[BedrockMessageParam] | None = None, + history: List[PromptMessageExtended] | None = None, ) -> PromptMessageExtended: """ Process a query using Bedrock and available tools. @@ -1216,10 +1237,11 @@ async def _bedrock_completion( f"Error accessing Bedrock: {error_msg}", ) from e - # Always include prompt messages, but only include conversation history - # if use_history is True - messages.extend(self.history.get(include_completion_history=params.use_history)) - messages.append(message_param) + # Convert supplied history/messages directly + if history: + messages.extend(self._convert_to_provider_format(history)) + else: + messages.append(message_param) # Get available tools (no resolver gating; fallback logic will decide wiring) tool_list = None @@ -1820,20 +1842,9 @@ async def _bedrock_completion( # Map stop reason to LlmStopReason mapped_stop_reason = self._map_bedrock_stop_reason(stop_reason) - # Update history - if params.use_history: - # Get current prompt messages - prompt_messages = self.history.get(include_completion_history=False) - - # Calculate new conversation messages (excluding prompts) - new_messages = messages[len(prompt_messages) :] - - # Remove system prompt from new messages if it was added - if (self.instruction or params.systemPrompt) and new_messages: - # System prompt is not added to messages list in Bedrock, so no need to remove it - pass - - self.history.set(new_messages) + # Update 
diagnostic snapshot (never read again) + # This provides a snapshot of what was sent to the provider for debugging + self.history.set(messages) self._log_chat_finished(model=model) @@ -1851,48 +1862,28 @@ async def _apply_prompt_provider_specific( tools: List[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: - """Apply Bedrock-specific prompt formatting.""" + """ + Provider-specific prompt application. + Templates are handled by the agent; messages already include them. + """ if not multipart_messages: return PromptMessageExtended(role="user", content=[]) # Check the last message role last_message = multipart_messages[-1] - # Add all previous messages to history (or all messages if last is from assistant) - # if the last message is a "user" inference is required - # if the last message is a "user" inference is required - messages_to_add = ( - multipart_messages[:-1] if last_message.role == "user" else multipart_messages - ) - converted = [] - for msg in messages_to_add: - # Convert each message to Bedrock message parameter format - bedrock_msg = self._convert_multipart_to_bedrock_message(msg) - converted.append(bedrock_msg) - - # Only persist prior messages when history is enabled; otherwise inline for this call - params = self.get_request_params(request_params) - pre_messages: List[BedrockMessageParam] | None = None - if params.use_history: - self.history.extend(converted, is_prompt=is_template) - else: - pre_messages = converted - if last_message.role == "assistant": # For assistant messages: Return the last message (no completion needed) return last_message - # For user messages with tool_results, we need to add the tool result message to the conversation - if last_message.tool_results: - # Convert the tool result message and use it as the final input - message_param = self._convert_multipart_to_bedrock_message(last_message) - else: - # Convert the last user message to Bedrock message parameter format - message_param = 
self._convert_multipart_to_bedrock_message(last_message) + # Convert the last user message to Bedrock message parameter format + message_param = BedrockConverter.convert_to_bedrock(last_message) - # Call the completion method with optional pre_messages for no-history mode + # Call the completion method + # No need to pass pre_messages - conversion happens in _bedrock_completion + # via _convert_to_provider_format() return await self._bedrock_completion( - message_param, request_params, tools, pre_messages=pre_messages + message_param, request_params, tools, pre_messages=None, history=multipart_messages ) def _generate_simplified_schema(self, model: Type[ModelT]) -> str: diff --git a/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py b/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py new file mode 100644 index 000000000..f567bd2c7 --- /dev/null +++ b/src/fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py @@ -0,0 +1,84 @@ +from typing import Any, Dict + +from fast_agent.types import PromptMessageExtended + +# Bedrock message format types +BedrockMessageParam = Dict[str, Any] + + +class BedrockConverter: + """Converts MCP message types to Bedrock API format.""" + + @staticmethod + def convert_to_bedrock(multipart_msg: PromptMessageExtended) -> BedrockMessageParam: + """ + Convert a PromptMessageExtended message to Bedrock API format. + + This is a wrapper around the instance method _convert_multipart_to_bedrock_message + to provide a static interface similar to AnthropicConverter. 
+ + Args: + multipart_msg: The PromptMessageExtended message to convert + + Returns: + A Bedrock API message parameter dictionary + """ + # Simple conversion without needing BedrockLLM instance + bedrock_msg = {"role": multipart_msg.role, "content": []} + + # Handle tool results first (if present) + if multipart_msg.tool_results: + import json + + from mcp.types import TextContent + + # Check if any tool ID indicates system prompt format + has_system_prompt_tools = any( + tool_id.startswith("system_prompt_") for tool_id in multipart_msg.tool_results.keys() + ) + + if has_system_prompt_tools: + # For system prompt models: format as human-readable text + tool_result_parts = [] + for tool_id, tool_result in multipart_msg.tool_results.items(): + result_text = "".join( + part.text for part in tool_result.content if isinstance(part, TextContent) + ) + result_payload = { + "tool_name": tool_id, + "status": "error" if tool_result.isError else "success", + "result": result_text, + } + tool_result_parts.append(json.dumps(result_payload)) + + if tool_result_parts: + full_result_text = f"Tool Results:\n{', '.join(tool_result_parts)}" + bedrock_msg["content"].append({"type": "text", "text": full_result_text}) + else: + # For Nova/Anthropic models: use structured tool_result format + for tool_id, tool_result in multipart_msg.tool_results.items(): + result_content_blocks = [] + if tool_result.content: + for part in tool_result.content: + if isinstance(part, TextContent): + result_content_blocks.append({"text": part.text}) + + if not result_content_blocks: + result_content_blocks.append({"text": "[No content in tool result]"}) + + bedrock_msg["content"].append( + { + "type": "tool_result", + "tool_use_id": tool_id, + "content": result_content_blocks, + "status": "error" if tool_result.isError else "success", + } + ) + + # Handle regular content + from mcp.types import TextContent + for content_item in multipart_msg.content: + if isinstance(content_item, TextContent): + 
bedrock_msg["content"].append({"type": "text", "text": content_item.text}) + + return bedrock_msg diff --git a/src/fast_agent/llm/provider/google/llm_google_native.py b/src/fast_agent/llm/provider/google/llm_google_native.py index 7cad7f0b7..cafbc523a 100644 --- a/src/fast_agent/llm/provider/google/llm_google_native.py +++ b/src/fast_agent/llm/provider/google/llm_google_native.py @@ -335,17 +335,8 @@ async def _google_completion( request_params = self.get_request_params(request_params=request_params) responses: List[ContentBlock] = [] - # Build conversation history from stored provider-specific messages - # and the provided message for this turn (no implicit conversion here). - # We store provider-native Content objects in history. - # Start with prompts + (optionally) accumulated conversation messages - base_history: List[types.Content] = self.history.get( - include_completion_history=request_params.use_history - ) - # Make a working copy and add the provided turn message(s) if present - conversation_history: List[types.Content] = list(base_history) - if message: - conversation_history.extend(message) + # Caller supplies the full set of messages to send (history + turn) + conversation_history: List[types.Content] = list(message or []) self.logger.debug(f"Google completion requested with messages: {conversation_history}") self._log_chat_progress(self.chat_turn(), model=request_params.model) @@ -473,13 +464,9 @@ async def _google_completion( else: stop_reason = self._map_finish_reason(getattr(candidate, "finish_reason", None)) - # 6. 
Persist conversation state to provider-native history (exclude prompt messages) - if request_params.use_history: - # History store separates prompt vs conversation messages; keep prompts as-is - prompt_messages = self.history.get(include_completion_history=False) - # messages after prompts are the true conversation history - new_messages = conversation_history[len(prompt_messages) :] - self.history.set(new_messages, is_prompt=False) + # Update diagnostic snapshot (never read again) + # This provides a snapshot of what was sent to the provider for debugging + self.history.set(conversation_history) self._log_chat_finished(model=request_params.model) # Use model from request_params return Prompt.assistant(*responses, stop_reason=stop_reason, tool_calls=tool_calls) @@ -494,31 +481,14 @@ async def _apply_prompt_provider_specific( is_template: bool = False, ) -> PromptMessageExtended: """ - Applies the prompt messages and potentially calls the LLM for completion. + Provider-specific prompt application. + Templates are handled by the agent; messages already include them. """ - request_params = self.get_request_params(request_params=request_params) # Determine the last message last_message = multipart_messages[-1] - # Add previous messages (excluding the last user message) to provider-native history - # If last is assistant, we add all messages and return it directly (no inference). 
- messages_to_add = ( - multipart_messages[:-1] if last_message.role == "user" else multipart_messages - ) - - if messages_to_add: - # Convert prior messages to google.genai Content - converted_prior = self._converter.convert_to_google_content(messages_to_add) - # Only persist prior context when history is enabled; otherwise inline later - if request_params.use_history: - self.history.extend(converted_prior, is_prompt=is_template) - else: - # Prepend prior context directly to the turn message list - # This keeps the single-turn chain intact without relying on provider memory - pass - if last_message.role == "assistant": # No generation required; the provided assistant message is the output return last_message @@ -532,7 +502,7 @@ async def _apply_prompt_provider_specific( # Map correlation IDs back to tool names using the last assistant tool_calls # found in our high-level message history id_to_name: Dict[str, str] = {} - for prev in reversed(self._message_history): + for prev in reversed(multipart_messages): if prev.role == "assistant" and prev.tool_calls: for call_id, call in prev.tool_calls.items(): try: @@ -557,19 +527,33 @@ async def _apply_prompt_provider_specific( # convert_to_google_content returns a list; preserve order after tool responses turn_messages.extend(user_contents) - # If not using provider history, include prior messages inline for this turn - if messages_to_add and not request_params.use_history: - prior_contents = self._converter.convert_to_google_content(messages_to_add) - turn_messages = prior_contents + turn_messages - # If we somehow have no provider-native parts, ensure we send an empty user content if not turn_messages: turn_messages.append(types.Content(role="user", parts=[types.Part.from_text("")])) - # Delegate to the native completion with explicit turn messages - return await self._google_completion( - turn_messages, request_params=request_params, tools=tools - ) + conversation_history: List[types.Content] = [] + if 
request_params.use_history and len(multipart_messages) > 1: + conversation_history.extend( + self._convert_to_provider_format(multipart_messages[:-1]) + ) + conversation_history.extend(turn_messages) + + return await self._google_completion(conversation_history, request_params=request_params, tools=tools) + + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[types.Content]: + """ + Convert PromptMessageExtended list to Google types.Content format. + This is called fresh on every API call from _convert_to_provider_format(). + + Args: + messages: List of PromptMessageExtended objects + + Returns: + List of Google types.Content objects + """ + return self._converter.convert_to_google_content(messages) def _map_finish_reason(self, finish_reason: object) -> LlmStopReason: """Map Google finish reasons to LlmStopReason robustly.""" @@ -611,21 +595,14 @@ async def _apply_prompt_provider_specific_structured( request_params=None, ): """ - Handles structured output for Gemini models using response_schema and response_mime_type, - keeping provider-native (google.genai) history consistent with non-structured calls. + Provider-specific structured output implementation. + Note: Message history is managed by base class and converted via + _convert_to_provider_format() on each call. 
""" import json - # Determine the last message and add prior messages to provider-native history + # Determine the last message last_message = multipart_messages[-1] if multipart_messages else None - messages_to_add = ( - multipart_messages - if last_message and last_message.role == "assistant" - else multipart_messages[:-1] - ) - if messages_to_add: - converted_prior = self._converter.convert_to_google_content(messages_to_add) - self.history.extend(converted_prior, is_prompt=False) # If the last message is an assistant message, attempt to parse its JSON and return if last_message and last_message.role == "assistant": diff --git a/src/fast_agent/llm/provider/openai/llm_openai.py b/src/fast_agent/llm/provider/openai/llm_openai.py index 999c503d8..0534d70a2 100644 --- a/src/fast_agent/llm/provider/openai/llm_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_openai.py @@ -698,8 +698,8 @@ async def _openai_completion( if system_prompt: messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt)) - messages.extend(self.history.get(include_completion_history=request_params.use_history)) - if message is not None: + # The caller supplies the full history; convert it directly + if message: messages.extend(message) available_tools: List[ChatCompletionToolParam] | None = [ @@ -821,17 +821,9 @@ async def _openai_completion( stop_reason = LlmStopReason.SAFETY self.logger.debug(" Stopping because finish_reason is 'content_filter'") - if request_params.use_history: - # Get current prompt messages - prompt_messages = self.history.get(include_completion_history=False) - - # Calculate new conversation messages (excluding prompts) - new_messages = messages[len(prompt_messages) :] - - if system_prompt: - new_messages = new_messages[1:] - - self.history.set(new_messages) + # Update diagnostic snapshot (never read again) + # This provides a snapshot of what was sent to the provider for debugging + self.history.set(messages) 
self._log_chat_finished(model=self.default_request_params.model) @@ -896,41 +888,25 @@ async def _apply_prompt_provider_specific( tools: List[Tool] | None = None, is_template: bool = False, ) -> PromptMessageExtended: - # Determine effective params to respect use_history for this turn + """ + Provider-specific prompt application. + Templates are handled by the agent; messages already include them. + """ + # Determine effective params req_params = self.get_request_params(request_params) last_message = multipart_messages[-1] - # Prepare prior messages (everything before the last user message), or all if last is assistant - messages_to_add = ( - multipart_messages[:-1] if last_message.role == "user" else multipart_messages - ) - - converted_prior: List[ChatCompletionMessageParam] = [] - for msg in messages_to_add: - # convert_to_openai now returns a list of messages - converted_prior.extend(OpenAIConverter.convert_to_openai(msg)) - # If the last message is from the assistant, no inference required if last_message.role == "assistant": return last_message - # Convert the last user message - converted_last = OpenAIConverter.convert_to_openai(last_message) - if not converted_last: - # Fallback for empty conversion - converted_last = [{"role": "user", "content": ""}] - - # History-aware vs stateless turn construction - if req_params.use_history: - # Persist prior context to provider memory; send only the last message for this turn - self.history.extend(converted_prior, is_prompt=is_template) - turn_messages = converted_last - else: - # Do NOT persist; inline the full turn context to the provider call - turn_messages = converted_prior + converted_last + # Convert the supplied history/messages directly + converted_messages = self._convert_to_provider_format(multipart_messages) + if not converted_messages: + converted_messages = [{"role": "user", "content": ""}] - return await self._openai_completion(turn_messages, req_params, tools) + return await 
self._openai_completion(converted_messages, req_params, tools) def _prepare_api_request( self, messages, tools: List[ChatCompletionToolParam] | None, request_params: RequestParams @@ -963,6 +939,27 @@ def _prepare_api_request( ) return arguments + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[ChatCompletionMessageParam]: + """ + Convert PromptMessageExtended list to OpenAI ChatCompletionMessageParam format. + This is called fresh on every API call from _convert_to_provider_format(). + + Args: + messages: List of PromptMessageExtended objects + + Returns: + List of OpenAI ChatCompletionMessageParam objects + """ + converted: List[ChatCompletionMessageParam] = [] + + for msg in messages: + # convert_to_openai returns a list of messages + converted.extend(OpenAIConverter.convert_to_openai(msg)) + + return converted + def adjust_schema(self, inputSchema: Dict) -> Dict: # return inputSchema if self.provider not in [Provider.OPENAI, Provider.AZURE]: diff --git a/src/fast_agent/mcp/prompt_message_extended.py b/src/fast_agent/mcp/prompt_message_extended.py index d444020e4..62a111f76 100644 --- a/src/fast_agent/mcp/prompt_message_extended.py +++ b/src/fast_agent/mcp/prompt_message_extended.py @@ -29,6 +29,7 @@ class PromptMessageExtended(BaseModel): tool_results: Dict[str, CallToolResult] | None = None channels: Mapping[str, Sequence[ContentBlock]] | None = None stop_reason: LlmStopReason | None = None + is_template: bool = False @classmethod def to_extended(cls, messages: List[PromptMessage]) -> List["PromptMessageExtended"]: diff --git a/src/fast_agent/mcp/prompts/prompt_load.py b/src/fast_agent/mcp/prompts/prompt_load.py index 950738cd3..305a36edf 100644 --- a/src/fast_agent/mcp/prompts/prompt_load.py +++ b/src/fast_agent/mcp/prompts/prompt_load.py @@ -9,6 +9,7 @@ from mcp.types import PromptMessage, TextContent from fast_agent.core.logging.logger import get_logger +from fast_agent.interfaces import AgentProtocol 
from fast_agent.mcp import mime_utils, resource_utils from fast_agent.mcp.prompts.prompt_template import ( PromptContent, @@ -156,3 +157,28 @@ def load_prompt_as_get_prompt_result(file: Path): # Convert to GetPromptResult (loses extended fields) return to_get_prompt_result(messages) + + +def load_history_into_agent(agent: AgentProtocol, file_path: Path) -> None: + """ + Load conversation history directly into agent without triggering LLM call. + + This function restores saved conversation state by directly setting the + agent's _message_history. No LLM API calls are made. + + Args: + agent: Agent instance to restore history into (FastAgentLLM or subclass) + file_path: Path to saved history file (JSON or template format) + + Note: + - The agent's history is cleared before loading + - Provider diagnostic history will be updated on the next API call + - Templates are cleared as well (clear() is called with clear_prompts=True) + """ + messages = load_prompt(file_path) + + # Direct restoration - no LLM call + agent.clear(clear_prompts=True) + agent.message_history.extend(messages) + + # Note: Provider diagnostic history will be updated on next API call diff --git a/src/fast_agent/mcp/server/agent_server.py b/src/fast_agent/mcp/server/agent_server.py index 77cbdaf5e..771de6c49 100644 --- a/src/fast_agent/mcp/server/agent_server.py +++ b/src/fast_agent/mcp/server/agent_server.py @@ -157,12 +157,13 @@ async def get_history_prompt(ctx: MCPContext) -> list: instance = await self._acquire_instance(ctx) agent = instance.app[agent_name] try: - if not hasattr(agent, "_llm") or agent._llm is None: + # Agent history is the authoritative source; LLM history is diagnostic only. 
+ history = getattr(agent, "message_history", None) + if history is None: return [] # Convert the multipart message history to standard PromptMessages - multipart_history = agent._llm.message_history - prompt_messages = fast_agent.core.prompt.Prompt.from_multipart(multipart_history) + prompt_messages = fast_agent.core.prompt.Prompt.from_multipart(history) # In FastMCP, we need to return the raw list of messages return [{"role": msg.role, "content": msg.content} for msg in prompt_messages] diff --git a/src/fast_agent/ui/interactive_prompt.py b/src/fast_agent/ui/interactive_prompt.py index 4fb095b53..7ec5cedd6 100644 --- a/src/fast_agent/ui/interactive_prompt.py +++ b/src/fast_agent/ui/interactive_prompt.py @@ -17,6 +17,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union, cast +from fast_agent.constants import CONTROL_MESSAGE_SAVE_HISTORY + if TYPE_CHECKING: from fast_agent.core.agent_app import AgentApp @@ -278,7 +280,9 @@ async def prompt_loop( rich_print(f"[green]History saved to {saved_path}[/green]") except Exception: # Fallback to magic string path for maximum compatibility - control = "***SAVE_HISTORY" + (f" {filename}" if filename else "") + control = CONTROL_MESSAGE_SAVE_HISTORY + ( + f" {filename}" if filename else "" + ) result = await send_func(control, agent) if result: rich_print(f"[green]{result}[/green]") @@ -291,22 +295,18 @@ async def prompt_loop( filename = command_dict.get("filename") try: - from fast_agent.mcp.prompts.prompt_load import load_prompt - - # Load the messages from the file - messages = load_prompt(Path(filename)) + from fast_agent.mcp.prompts.prompt_load import load_history_into_agent - # Get the agent object + # Get the agent object and its underlying LLM agent_obj = prompt_provider._agent(agent) - # Clear the agent's history first - agent_obj.clear() + # Load history directly without triggering LLM call + load_history_into_agent(agent_obj, Path(filename)) - # Load 
the messages into the agent's history - # We use generate() to properly process the loaded history - await agent_obj.generate(messages) - - rich_print(f"[green]History loaded from {filename}[/green]") + msg_count = len(agent_obj.message_history) + rich_print( + f"[green]Loaded {msg_count} messages from {filename}[/green]" + ) except FileNotFoundError: rich_print(f"[red]File not found: {filename}[/red]") except Exception as e: diff --git a/tests/e2e/history/test_history_save_load_e2e.py b/tests/e2e/history/test_history_save_load_e2e.py new file mode 100644 index 000000000..e017cd27a --- /dev/null +++ b/tests/e2e/history/test_history_save_load_e2e.py @@ -0,0 +1,155 @@ +import os +from contextlib import asynccontextmanager +from pathlib import Path +from typing import AsyncIterator + +import pytest +from mcp.types import CallToolResult, TextContent, Tool + +from fast_agent.agents.agent_types import AgentConfig +from fast_agent.agents.llm_agent import LlmAgent +from fast_agent.core import Core +from fast_agent.llm.model_factory import ModelFactory +from fast_agent.llm.request_params import RequestParams +from fast_agent.mcp.prompt_message_extended import PromptMessageExtended +from fast_agent.mcp.prompt_serialization import save_messages +from fast_agent.mcp.prompts.prompt_load import load_history_into_agent +from fast_agent.types.llm_stop_reason import LlmStopReason + +TEST_CONFIG_PATH = Path(__file__).resolve().parent.parent / "llm" / "fastagent.config.yaml" +DEFAULT_CREATE_MODELS = [ + "gpt-5-mini.minimal", + "haiku", + "gemini25", + "minimax", + "kimi", + "qwen3", + "glm", +] +DEFAULT_CHECK_MODELS = ["haiku", "kimigroq", "gpt-5-mini.minimal", "kimi", "qwen3", "glm"] +MAGIC_STRING = "MAGIC-ACCESS-PHRASE-9F1C" +MAGIC_TOOL = Tool( + name="fetch_magic_string", + description="Returns the daily passphrase when the assistant must call a tool.", + inputSchema={ + "type": "object", + "properties": { + "purpose": { + "type": "string", + "description": "Explain why you need 
the passphrase. Must always be supplied.", + } + }, + "required": ["purpose"], + }, +) + + +def _parse_model_list(raw: str | None, default: list[str]) -> list[str]: + if not raw: + return default + parsed = [value.strip() for value in raw.split(",") if value.strip()] + return parsed or default + + +CREATE_MODELS = _parse_model_list( + os.environ.get("FAST_AGENT_HISTORY_CREATE_MODELS"), DEFAULT_CREATE_MODELS +) +CHECK_MODELS = _parse_model_list( + os.environ.get("FAST_AGENT_HISTORY_CHECK_MODELS"), DEFAULT_CHECK_MODELS +) +MODEL_MATRIX = [(create, check) for create in CREATE_MODELS for check in CHECK_MODELS] +_HISTORY_CACHE: dict[str, Path] = {} + + +def _sanitize_model_name(model: str) -> str: + return model.replace("/", "_").replace(":", "_").replace(".", "-").replace(" ", "-").lower() + + +@asynccontextmanager +async def agent_session(model_name: str, label: str) -> AsyncIterator[LlmAgent]: + core = Core(settings=str(TEST_CONFIG_PATH)) + async with core.run(): + agent = LlmAgent(AgentConfig(label), core.context) + await agent.attach_llm(ModelFactory.create_factory(model_name)) + yield agent + + +async def _create_history(agent: LlmAgent) -> None: + greeting = await agent.generate( + "The following messages are part of a test of our LLM history functions. Let's start with a quick friendly greeting." + ) + assert greeting.stop_reason is LlmStopReason.END_TURN + + request = ( + "Call the fetch_magic_string tool to obtain today's secret passphrase. " + "You must call the tool before you can continue." 
+ ) + tool_call = await agent.generate( + request, + tools=[MAGIC_TOOL], + request_params=RequestParams(maxTokens=300), + ) + assert tool_call.stop_reason is LlmStopReason.TOOL_USE + assert tool_call.tool_calls + tool_id = next(iter(tool_call.tool_calls.keys())) + + tool_result = CallToolResult(content=[TextContent(type="text", text=MAGIC_STRING)]) + user_tool_message = PromptMessageExtended( + role="user", + content=[ + TextContent( + type="text", + text="Here is the tool output. Read it carefully and repeat the passphrase verbatim.", + ) + ], + tool_results={tool_id: tool_result}, + ) + confirmation = await agent.generate(user_tool_message) + # confirmation_text = (confirmation.all_text() or "").lower() + assert LlmStopReason.END_TURN == confirmation.stop_reason + # assert MAGIC_STRING.lower() in confirmation_text + + wrap_up = await agent.generate( + "Great. Say something brief about keeping that passphrase safe so I know you stored it." + ) + assert wrap_up.stop_reason is LlmStopReason.END_TURN + + +async def _load_and_verify(agent: LlmAgent, history_file: Path) -> None: + load_history_into_agent(agent, history_file) + + follow_up = await agent.generate( + "Without inventing anything new, what exact passphrase did fetch_magic_string return earlier?" + ) + follow_text = (follow_up.all_text() or "").lower() + assert MAGIC_STRING.lower() in follow_text + + +async def _get_or_create_history_file(create_model: str, tmp_path_factory) -> Path: + """ + Create history once per creator model and reuse the saved file across check models. 
+ """ + cached = _HISTORY_CACHE.get(create_model) + if cached and cached.exists(): + return cached + + history_dir = tmp_path_factory.mktemp(f"history-{_sanitize_model_name(create_model)}") + history_file = Path(history_dir) / "history.json" + + async with agent_session(create_model, f"history-create-{create_model}") as creator_agent: + await _create_history(creator_agent) + save_messages(creator_agent.message_history, history_file) + + assert history_file.exists() + _HISTORY_CACHE[create_model] = history_file + return history_file + + +@pytest.mark.e2e +@pytest.mark.asyncio +@pytest.mark.parametrize("create_model,check_model", MODEL_MATRIX) +async def test_history_survives_across_models(tmp_path_factory, create_model, check_model): + history_file = await _get_or_create_history_file(create_model, tmp_path_factory) + + async with agent_session(check_model, f"history-load-{check_model}") as checker_agent: + await _load_and_verify(checker_agent, history_file) diff --git a/tests/e2e/llm/test_llm_e2e.py b/tests/e2e/llm/test_llm_e2e.py index 444544d5a..70ded1d2d 100644 --- a/tests/e2e/llm/test_llm_e2e.py +++ b/tests/e2e/llm/test_llm_e2e.py @@ -56,7 +56,7 @@ def get_test_models(): "kimigroq", "kimi", "glm", - "qwen3", + "qwen3:together", "deepseek31", # "responses.gpt-5-mini", # "generic.qwen3:8b", @@ -228,7 +228,7 @@ async def test_tool_const_schema(llm_agent_setup, model_name): """Ensure providers accept tool schemas that include const constraints.""" agent = llm_agent_setup # should really refer to model db and extend all reasoning models :) - max_tokens = 500 if ("minimax" in model_name or "glm" in model_name) else 100 + max_tokens = 500 if ("minimax" in model_name or "glm" in model_name) else 200 result = await agent.generate( "call the const_mode tool so I can confirm the mode you must use.", tools=[_const_tool], diff --git a/tests/integration/api/test_prompt_listing.py b/tests/integration/api/test_prompt_listing.py index 1b6ec6b11..f5bf5f8da 100644 --- 
a/tests/integration/api/test_prompt_listing.py +++ b/tests/integration/api/test_prompt_listing.py @@ -73,6 +73,6 @@ async def agent_function(): # Verify the prompt was applied assert response, "No response from apply_prompt" - assert len(agent.test._llm.message_history) > 0 + assert len(agent.test.message_history) > 0 await agent_function() diff --git a/tests/integration/history-architecture/test_history_architecture.py b/tests/integration/history-architecture/test_history_architecture.py new file mode 100644 index 000000000..243950e88 --- /dev/null +++ b/tests/integration/history-architecture/test_history_architecture.py @@ -0,0 +1,178 @@ +""" +Integration tests for the new conversation history architecture. + +These tests verify that: +1. Agent message_history is the single source of truth +2. Provider history is diagnostic only (write-only) +3. load_history doesn't trigger LLM calls +4. Templates are correctly handled +""" + +import pytest + +from fast_agent.core.prompt import Prompt +from fast_agent.mcp.prompt_serialization import save_messages +from fast_agent.mcp.prompts.prompt_load import load_history_into_agent + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_load_history_no_llm_call(fast_agent, tmp_path): + """ + Verify that load_history_into_agent() does NOT trigger an LLM API call. + + This test ensures the bug fix where load_history previously called generate(). 
+ """ + fast = fast_agent + + # Create a temporary history file with a simple conversation + history_file = tmp_path / "test_history.json" + messages = [ + Prompt.user("Hello"), + Prompt.assistant("Hi there!"), + Prompt.user("How are you?"), + ] + + # Save using the proper serialization format + save_messages(messages, str(history_file)) + + @fast.agent(model="passthrough") + async def agent_function(): + async with fast.run() as agent: + agent_obj = agent.default + + # Get initial message count + initial_count = len(agent_obj.message_history) + assert initial_count == 0, "Agent should start with no history" + + # Load history - this should NOT make an LLM call + load_history_into_agent(agent_obj, history_file) + + # Verify history was loaded + loaded_count = len(agent_obj.message_history) + assert loaded_count == 3, f"Expected 3 messages, got {loaded_count}" + + # Verify content + assert agent_obj.message_history[0].role == "user" + assert "Hello" in agent_obj.message_history[0].first_text() + assert agent_obj.message_history[1].role == "assistant" + assert "Hi there!" in agent_obj.message_history[1].first_text() + + await agent_function() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_message_history_source_of_truth(fast_agent): + """ + Verify that _message_history is the single source of truth. + + Provider history should be diagnostic only and not read for API calls. 
+ """ + fast = fast_agent + + @fast.agent(model="passthrough") + async def agent_function(): + async with fast.run() as agent: + agent_obj = agent.default + + # Start with empty histories + assert len(agent_obj.message_history) == 0 + + # Manually add a message to message_history + test_msg = Prompt.user("Test message") + agent_obj.message_history.append(test_msg) + + # Verify message is in message history + assert len(agent_obj.message_history) == 1 + assert agent_obj.message_history[0].first_text() == "Test message" + + # Provider history should still be empty (no API call yet) + # This verifies that message_history is independent of provider history + + await agent_function() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_template_persistence_after_clear(fast_agent): + """ + Verify that template messages are preserved after clear() but removed after clear(clear_prompts=True). + """ + fast = fast_agent + + @fast.agent(model="passthrough") + async def agent_function(): + async with fast.run() as agent: + agent_obj = agent.default + + # Create template messages directly + template_msgs = [ + Prompt.user("You are a helpful assistant."), + Prompt.assistant("I understand."), + ] + template_msgs = [msg.model_copy(update={"is_template": True}) for msg in template_msgs] + agent_obj._message_history = [msg.model_copy(deep=True) for msg in template_msgs] + + # Verify template is loaded + assert len(agent_obj.template_messages) == 2 + assert len(agent_obj.message_history) == 2 + + # Add a user message + user_msg = Prompt.user("New message") + agent_obj._message_history.append(user_msg) + assert len(agent_obj.message_history) == 3 + + # Clear without clearing prompts + agent_obj.clear() + + # Templates should be restored, new message should be gone + assert len(agent_obj.message_history) == 2 + assert len(agent_obj.template_messages) == 2 + + # Add another message + agent_obj._message_history.append(user_msg) + assert len(agent_obj.message_history) == 3 
+ + # Clear with clear_prompts=True + agent_obj.clear(clear_prompts=True) + + # Everything should be gone + assert len(agent_obj.message_history) == 0 + assert len(agent_obj.template_messages) == 0 + + await agent_function() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_provider_history_diagnostic_only(fast_agent): + """ + Verify that provider history (self.history) is diagnostic only. + + The provider should NOT read from self.history for API calls. + """ + fast = fast_agent + + @fast.agent(model="passthrough") + async def agent_function(): + async with fast.run() as agent: + agent_obj = agent.default + llm = agent_obj._llm + + # Start with empty histories + assert len(agent_obj.message_history) == 0 + + # Manually add a message to message_history + test_msg = Prompt.user("Test") + agent_obj.message_history.append(test_msg) + + # Verify it's in _message_history + assert len(agent_obj.message_history) == 1 + + # Provider history should still be empty (until an API call is made) + # This confirms that _message_history is independent of provider history + # and that provider history is only written to, not read from + assert len(llm.history.get()) == 0 + + await agent_function() diff --git a/tests/integration/prompt-state/test_load_prompt_templates.py b/tests/integration/prompt-state/test_load_prompt_templates.py index b6adbb678..a1693def4 100644 --- a/tests/integration/prompt-state/test_load_prompt_templates.py +++ b/tests/integration/prompt-state/test_load_prompt_templates.py @@ -5,6 +5,7 @@ import pytest from mcp.types import ImageContent +from fast_agent.constants import CONTROL_MESSAGE_SAVE_HISTORY from fast_agent.core.prompt import Prompt from fast_agent.mcp.prompts.prompt_load import ( load_prompt, @@ -33,7 +34,7 @@ async def agent_function(): # Use the "default" agent directly response = await agent.default.generate(loaded) assert "message 2" in agent.default.message_history[-4].first_text() - assert "message 3" in response.first_text() 
+ assert "message 4" in response.first_text() await agent_function() @@ -84,7 +85,7 @@ async def agent_function(): os.remove("./simple.txt") await agent.send("hello") await agent.send("world") - await agent.send("***SAVE_HISTORY simple.txt") + await agent.send(f"{CONTROL_MESSAGE_SAVE_HISTORY} simple.txt") prompts: list[PromptMessageExtended] = load_prompt(Path("simple.txt")) assert 4 == len(prompts) @@ -120,7 +121,7 @@ async def agent_function(): await agent.send("world") # Save in JSON format (filename ends with .json) - await agent.send("***SAVE_HISTORY history.json") + await agent.send(f"{CONTROL_MESSAGE_SAVE_HISTORY} history.json") # Verify file exists assert os.path.exists("./history.json") @@ -183,7 +184,7 @@ async def agent_function(): await agent.test.generate([Prompt.user("good morning")]) await agent.test.generate([Prompt.user("what's in this image", Path("conv2_img.png"))]) - await agent.send("***SAVE_HISTORY multipart.json") + await agent.send(f"{CONTROL_MESSAGE_SAVE_HISTORY} multipart.json") prompts: list[PromptMessageExtended] = load_prompt(Path("./multipart.json")) assert 4 == len(prompts) diff --git a/tests/integration/workflow/chain/test_chain_passthrough.py b/tests/integration/workflow/chain/test_chain_passthrough.py index d810c0a80..a9ab6ffe1 100644 --- a/tests/integration/workflow/chain/test_chain_passthrough.py +++ b/tests/integration/workflow/chain/test_chain_passthrough.py @@ -42,7 +42,13 @@ async def chain_workflow(): # Renamed from main to avoid conflicts, and wrapped assert result == input_url result = await agent.topic_writer_cumulative.send("X") - # we expect the result to include tagged responses from all agents. 
- assert "X\nX\nX\nX" in result + # Expect tagged responses from all agents in cumulative output + assert "X" in result + assert "X" in result + assert "X\nX" in result + assert ( + "X\nX\nX\nX" + in result + ) await chain_workflow() # Call the inner function diff --git a/tests/unit/test_filesystem_runtime_integration.py b/tests/unit/acp/test_filesystem_runtime_integration.py similarity index 100% rename from tests/unit/test_filesystem_runtime_integration.py rename to tests/unit/acp/test_filesystem_runtime_integration.py diff --git a/tests/unit/fast_agent/agents/test_agent_history_binding.py b/tests/unit/fast_agent/agents/test_agent_history_binding.py new file mode 100644 index 000000000..32919a2f3 --- /dev/null +++ b/tests/unit/fast_agent/agents/test_agent_history_binding.py @@ -0,0 +1,68 @@ +import pytest +from mcp.types import TextContent + +from fast_agent.agents.agent_types import AgentConfig +from fast_agent.agents.llm_agent import LlmAgent +from fast_agent.core.prompt import Prompt +from fast_agent.llm.fastagent_llm import FastAgentLLM +from fast_agent.llm.provider_types import Provider +from fast_agent.llm.request_params import RequestParams +from fast_agent.types import PromptMessageExtended + + +class FakeLLM(FastAgentLLM[PromptMessageExtended, PromptMessageExtended]): + def __init__(self, **kwargs): + super().__init__(provider=Provider.FAST_AGENT, name="fake-llm", **kwargs) + self.last_messages: list[PromptMessageExtended] | None = None + + async def _apply_prompt_provider_specific( + self, + multipart_messages: list[PromptMessageExtended], + request_params: RequestParams | None = None, + tools=None, + is_template: bool = False, + ) -> PromptMessageExtended: + self.last_messages = list(multipart_messages) + return Prompt.assistant("ok") + + async def _apply_prompt_provider_specific_structured( + self, + multipart_messages: list[PromptMessageExtended], + model, + request_params: RequestParams | None = None, + ): + self.last_messages = 
list(multipart_messages) + return None, Prompt.assistant("ok") + + def _convert_extended_messages_to_provider( + self, messages: list[PromptMessageExtended] + ) -> list[PromptMessageExtended]: + return messages + + +@pytest.mark.asyncio +async def test_templates_sent_when_history_disabled(): + agent = LlmAgent(AgentConfig("test-agent")) + llm = FakeLLM() + agent._llm = llm + + # Seed a template baseline and make sure history mirrors it + template_result = PromptMessageExtended( + role="user", + content=[TextContent(type="text", text="template baseline")], + is_template=True, + ) + agent._message_history = [template_result.model_copy(deep=True)] + + user_msg = PromptMessageExtended( + role="user", content=[TextContent(type="text", text="hello world")] + ) + + response = await agent.generate_impl([user_msg], RequestParams(use_history=False)) + + assert llm.last_messages is not None + assert llm.last_messages[0].first_text() == template_result.first_text() + # History not extended when use_history is False (template remains) + assert len(agent.message_history) == 1 + assert agent.message_history[0].first_text() == template_result.first_text() + assert response.role == "assistant" diff --git a/tests/unit/fast_agent/agents/test_llm_content_filter.py b/tests/unit/fast_agent/agents/test_llm_content_filter.py index f1c8e64ad..f321e9cc3 100644 --- a/tests/unit/fast_agent/agents/test_llm_content_filter.py +++ b/tests/unit/fast_agent/agents/test_llm_content_filter.py @@ -16,23 +16,33 @@ FAST_AGENT_ERROR_CHANNEL, FAST_AGENT_REMOVED_METADATA_CHANNEL, ) +from fast_agent.interfaces import FastAgentLLMProtocol from fast_agent.llm.provider_types import Provider from fast_agent.types import PromptMessageExtended, text_content -class RecordingStubLLM: +class RecordingStubLLM(FastAgentLLMProtocol): """Minimal FastAgentLLMProtocol implementation for testing.""" def __init__(self, model_name: str = "passthrough") -> None: - self.model_name = model_name - self.provider = 
Provider.FAST_AGENT + self._model_name = model_name + self._provider = Provider.FAST_AGENT self.generated_messages: list[PromptMessageExtended] | None = None self._message_history: list[PromptMessageExtended] = [] - self.usage_accumulator = None + + # self.usage_accumulator = None + + @property + def model_name(self) -> str | None: + return self._model_name + + @property + def provider(self) -> Provider: + return self._provider async def generate(self, messages, request_params=None, tools=None): self.generated_messages = messages - self._message_history.extend(messages) + self._message_history = messages return PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="ok")], @@ -166,9 +176,7 @@ async def test_metadata_clears_when_supported_content_only(): channels = (stub.generated_messages or [])[0].channels or {} assert FAST_AGENT_REMOVED_METADATA_CHANNEL in channels - second_message = PromptMessageExtended( - role="user", content=[text_content("Next turn")] - ) + second_message = PromptMessageExtended(role="user", content=[text_content("Next turn")]) await decorator.generate_impl([second_message]) assert stub.generated_messages is not None diff --git a/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py b/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py new file mode 100644 index 000000000..efdfe3242 --- /dev/null +++ b/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py @@ -0,0 +1,77 @@ +from mcp.types import TextContent + +from fast_agent.llm.provider.anthropic.cache_planner import AnthropicCachePlanner +from fast_agent.llm.provider.anthropic.llm_anthropic import AnthropicLLM +from fast_agent.llm.provider.anthropic.multipart_converter_anthropic import AnthropicConverter +from fast_agent.mcp.prompt_message_extended import PromptMessageExtended + + +def make_message(text: str, *, is_template: bool = False) -> PromptMessageExtended: + return PromptMessageExtended( 
+ role="user", content=[TextContent(type="text", text=text)], is_template=is_template + ) + + +def count_cache_controls(messages: list[dict]) -> int: + return sum( + 1 + for msg in messages + for block in msg.get("content", []) + if isinstance(block, dict) and block.get("cache_control") + ) + + +def test_template_cache_respects_budget(): + planner = AnthropicCachePlanner(max_total_blocks=4) + extended = [ + make_message("template 1", is_template=True), + make_message("template 2", is_template=True), + make_message("user turn"), + ] + + plan_indices = planner.plan_indices(extended, cache_mode="prompt", system_cache_blocks=0) + provider_msgs = [AnthropicConverter.convert_to_anthropic(msg) for msg in extended] + + for idx in plan_indices: + AnthropicLLM._apply_cache_control_to_message(provider_msgs[idx]) + + assert "cache_control" in provider_msgs[0]["content"][-1] + assert "cache_control" in provider_msgs[1]["content"][-1] + + +def test_conversation_cache_respects_four_block_limit(): + planner = AnthropicCachePlanner(max_total_blocks=4) + system_cache_blocks = 1 + extended = [ + make_message("template 1", is_template=True), + make_message("template 2", is_template=True), + ] + extended.extend(make_message(f"turn {i}") for i in range(6)) + + plan_indices = planner.plan_indices(extended, cache_mode="auto", system_cache_blocks=system_cache_blocks) + provider_msgs = [AnthropicConverter.convert_to_anthropic(msg) for msg in extended] + for idx in plan_indices: + AnthropicLLM._apply_cache_control_to_message(provider_msgs[idx]) + + total_cache_blocks = system_cache_blocks + count_cache_controls(provider_msgs) + + assert total_cache_blocks <= 4 + assert len([i for i in plan_indices if i >= 2]) <= 1 # system + templates leave one slot + + +def test_conversation_cache_waits_for_walk_distance(): + planner = AnthropicCachePlanner(max_total_blocks=4) + extended = [ + make_message("template", is_template=True), + make_message("user 1"), + make_message("assistant 1"), + ] + + 
plan_indices = planner.plan_indices(extended, cache_mode="auto", system_cache_blocks=0) + provider_msgs = [AnthropicConverter.convert_to_anthropic(msg) for msg in extended] + + assert plan_indices == [0] + for idx in plan_indices: + AnthropicLLM._apply_cache_control_to_message(provider_msgs[idx]) + + assert count_cache_controls(provider_msgs) == 1 diff --git a/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py b/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py new file mode 100644 index 000000000..1ffade622 --- /dev/null +++ b/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING + +from mcp.types import CallToolRequest, CallToolRequestParams, CallToolResult, TextContent + +from fast_agent.llm.provider.anthropic.multipart_converter_anthropic import AnthropicConverter +from fast_agent.types import PromptMessageExtended + +if TYPE_CHECKING: + from anthropic.types import MessageParam + + +def test_sanitizes_tool_use_ids_for_assistant_calls(): + dirty_id = "functions.fetch_magic_string:0" + expected = "functions_fetch_magic_string_0" + params = CallToolRequestParams(name="fetch_magic_string", arguments={}) + req = CallToolRequest(params=params) + + msg = PromptMessageExtended(role="assistant", content=[], tool_calls={dirty_id: req}) + + converted: MessageParam = AnthropicConverter.convert_to_anthropic(msg) + + assert converted["role"] == "assistant" + assert converted["content"][0]["id"] == expected + + +def test_sanitizes_tool_use_ids_for_tool_results(): + dirty_id = "functions.fetch_magic_string:0" + expected = "functions_fetch_magic_string_0" + result = CallToolResult(content=[TextContent(type="text", text="done")], isError=False) + + msg = PromptMessageExtended(role="user", content=[], tool_results={dirty_id: result}) + + converted: MessageParam = AnthropicConverter.convert_to_anthropic(msg) + + assert converted["role"] == "user" + assert 
converted["content"][0]["tool_use_id"] == expected diff --git a/tests/unit/fast_agent/llm/providers/test_augmented_llm_anthropic_caching.py b/tests/unit/fast_agent/llm/providers/test_augmented_llm_anthropic_caching.py deleted file mode 100644 index 7432844cf..000000000 --- a/tests/unit/fast_agent/llm/providers/test_augmented_llm_anthropic_caching.py +++ /dev/null @@ -1,403 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -from mcp.types import TextContent - -from fast_agent.config import AnthropicSettings, Settings -from fast_agent.llm.provider.anthropic.llm_anthropic import AnthropicLLM -from fast_agent.mcp.prompt_message_extended import PromptMessageExtended - - -class TestAnthropicCaching(unittest.IsolatedAsyncioTestCase): - """Test cases for Anthropic caching functionality.""" - - def setUp(self): - """Set up test environment.""" - self.mock_context = MagicMock() - self.mock_context.config = Settings() - self.mock_aggregator = AsyncMock() - self.mock_aggregator.list_tools = AsyncMock( - return_value=MagicMock( - tools=[ - MagicMock( - name="test_tool", - description="Test tool", - inputSchema={"type": "object", "properties": {}}, - ) - ] - ) - ) - - def _create_llm(self, cache_mode: str = "off") -> AnthropicLLM: - """Create an AnthropicLLM instance with specified cache mode.""" - self.mock_context.config.anthropic = AnthropicSettings( - api_key="test_key", cache_mode=cache_mode - ) - - llm = AnthropicLLM(context=self.mock_context, aggregator=self.mock_aggregator) - return llm - - @patch("fast_agent.llm.provider.anthropic.llm_anthropic.AsyncAnthropic") - async def test_caching_off_mode(self, mock_anthropic_class): - """Test that no caching is applied when cache_mode is 'off'.""" - llm = self._create_llm(cache_mode="off") - llm.instruction = "Test system prompt" - - # Capture the arguments passed to the streaming API - captured_args = None - - # Mock the Anthropic client - mock_client = MagicMock() - 
mock_anthropic_class.return_value = mock_client - - # Create a proper async context manager for the stream - class MockStream: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return None - - def __aiter__(self): - return iter([]) - - # Capture arguments and return the mock stream - def stream_method(**kwargs): - nonlocal captured_args - captured_args = kwargs - return MockStream() - - mock_client.messages.stream = stream_method - - # Mock the _process_stream method to return a response - # Create a usage mock that won't trigger warnings - mock_usage = MagicMock() - mock_usage.input_tokens = 100 - mock_usage.output_tokens = 50 - mock_usage.cache_creation_input_tokens = None - mock_usage.cache_read_input_tokens = None - mock_usage.trafficType = None # Add trafficType to prevent Google genai warning - - mock_response = MagicMock( - content=[MagicMock(type="text", text="Test response")], - stop_reason="end_turn", - usage=mock_usage, - ) - llm._process_stream = AsyncMock(return_value=mock_response) - - # Create a test message - message_param = {"role": "user", "content": [{"type": "text", "text": "Test message"}]} - - # Run the completion - await llm._anthropic_completion(message_param) - - # Verify arguments were captured - self.assertIsNotNone(captured_args) - - # Check that system prompt exists but has no cache_control - system = captured_args.get("system") - self.assertIsNotNone(system) - - # When cache_mode is "off", system should remain a string - self.assertIsInstance(system, str) - self.assertEqual(system, "Test system prompt") - - @patch("fast_agent.llm.provider.anthropic.llm_anthropic.AsyncAnthropic") - async def test_caching_prompt_mode(self, mock_anthropic_class): - """Test caching behavior in 'prompt' mode.""" - llm = self._create_llm(cache_mode="prompt") - llm.instruction = "Test system prompt" - - # Capture the arguments passed to the streaming API - captured_args = None - - # Mock the Anthropic client - 
mock_client = MagicMock() - mock_anthropic_class.return_value = mock_client - - # Create a proper async context manager for the stream - class MockStream: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return None - - def __aiter__(self): - return iter([]) - - # Capture arguments and return the mock stream - def stream_method(**kwargs): - nonlocal captured_args - captured_args = kwargs - return MockStream() - - mock_client.messages.stream = stream_method - - # Mock the _process_stream method to return a response - # Create a usage mock that won't trigger warnings - mock_usage = MagicMock() - mock_usage.input_tokens = 100 - mock_usage.output_tokens = 50 - mock_usage.cache_creation_input_tokens = None - mock_usage.cache_read_input_tokens = None - mock_usage.trafficType = None # Add trafficType to prevent Google genai warning - - mock_response = MagicMock( - content=[MagicMock(type="text", text="Test response")], - stop_reason="end_turn", - usage=mock_usage, - ) - llm._process_stream = AsyncMock(return_value=mock_response) - - # Create a test message - message_param = {"role": "user", "content": [{"type": "text", "text": "Test message"}]} - - # Run the completion - await llm._anthropic_completion(message_param) - - # Verify arguments were captured - self.assertIsNotNone(captured_args) - - # Check that system prompt has cache_control when cache_mode is "prompt" - system = captured_args.get("system") - self.assertIsNotNone(system) - - # When cache_mode is "prompt", system should be converted to a list with cache_control - self.assertIsInstance(system, list) - self.assertEqual(len(system), 1) - self.assertEqual(system[0]["type"], "text") - self.assertEqual(system[0]["text"], "Test system prompt") - self.assertIn("cache_control", system[0]) - self.assertEqual(system[0]["cache_control"]["type"], "ephemeral") - - # Note: According to the code comment, tools and system are cached together - # via the system prompt, so tools 
themselves don't get cache_control - - @patch("fast_agent.llm.provider.anthropic.llm_anthropic.AsyncAnthropic") - async def test_caching_auto_mode(self, mock_anthropic_class): - """Test caching behavior in 'auto' mode.""" - llm = self._create_llm(cache_mode="auto") - llm.instruction = "Test system prompt" - - # Add some messages to history to test message caching - llm.history.extend( - [ - {"role": "user", "content": [{"type": "text", "text": "First message"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "First response"}]}, - {"role": "user", "content": [{"type": "text", "text": "Second message"}]}, - ] - ) - - # Capture the arguments passed to the streaming API - captured_args = None - - # Mock the Anthropic client - mock_client = MagicMock() - mock_anthropic_class.return_value = mock_client - - # Create a proper async context manager for the stream - class MockStream: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return None - - def __aiter__(self): - return iter([]) - - # Capture arguments and return the mock stream - def stream_method(**kwargs): - nonlocal captured_args - captured_args = kwargs - return MockStream() - - mock_client.messages.stream = stream_method - - # Mock the _process_stream method to return a response - # Create a usage mock that won't trigger warnings - mock_usage = MagicMock() - mock_usage.input_tokens = 100 - mock_usage.output_tokens = 50 - mock_usage.cache_creation_input_tokens = None - mock_usage.cache_read_input_tokens = None - mock_usage.trafficType = None # Add trafficType to prevent Google genai warning - - mock_response = MagicMock( - content=[MagicMock(type="text", text="Test response")], - stop_reason="end_turn", - usage=mock_usage, - ) - llm._process_stream = AsyncMock(return_value=mock_response) - - # Create a test message - message_param = {"role": "user", "content": [{"type": "text", "text": "Test message"}]} - - # Run the completion - await 
llm._anthropic_completion(message_param) - - # Verify arguments were captured - self.assertIsNotNone(captured_args) - - # Check that system prompt has cache_control when cache_mode is "auto" - system = captured_args.get("system") - self.assertIsNotNone(system) - - # When cache_mode is "auto", system should be converted to a list with cache_control - self.assertIsInstance(system, list) - self.assertEqual(len(system), 1) - self.assertEqual(system[0]["type"], "text") - self.assertEqual(system[0]["text"], "Test system prompt") - self.assertIn("cache_control", system[0]) - self.assertEqual(system[0]["cache_control"]["type"], "ephemeral") - - # In auto mode, conversation messages may have cache control if there are enough messages - messages = captured_args.get("messages", []) - self.assertGreater(len(messages), 0) - - # Verify we have the expected messages - # History has 3 messages + prompt messages (if any) + the new message - # Let's just verify we have messages and the structure is correct - self.assertGreaterEqual(len(messages), 4) # At least the history + new message - - @patch("fast_agent.llm.provider.anthropic.llm_anthropic.AsyncAnthropic") - async def test_template_caching_prompt_mode(self, mock_anthropic_class): - """Test that template messages are cached in 'prompt' mode.""" - llm = self._create_llm(cache_mode="prompt") - - # Mock the Anthropic client - mock_client = MagicMock() - mock_anthropic_class.return_value = mock_client - - # Create a proper async context manager for the stream - class MockStream: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return None - - def __aiter__(self): - return iter([]) - - # Mock the stream method - mock_client.messages.stream = lambda **kwargs: MockStream() - - # Mock the _process_stream method to return a response - mock_usage = MagicMock() - mock_usage.input_tokens = 100 - mock_usage.output_tokens = 50 - mock_usage.cache_creation_input_tokens = None - 
mock_usage.cache_read_input_tokens = None - mock_usage.trafficType = None - - mock_response = MagicMock( - content=[MagicMock(type="text", text="Response")], - stop_reason="end_turn", - usage=mock_usage, - ) - llm._process_stream = AsyncMock(return_value=mock_response) - - # Create template messages - template_messages = [ - PromptMessageExtended( - role="user", content=[TextContent(type="text", text="Template message 1")] - ), - PromptMessageExtended( - role="assistant", content=[TextContent(type="text", text="Template response 1")] - ), - PromptMessageExtended( - role="user", content=[TextContent(type="text", text="Current question")] - ), - ] - - # Apply template with is_template=True - await llm._apply_prompt_provider_specific( - template_messages, request_params=None, tools=None, is_template=True - ) - - # Check that template messages in history have cache control - history_messages = llm.history.get(include_completion_history=False) - - # Verify that at least one template message has cache control - found_cache_control = False - for msg in history_messages: - if isinstance(msg, dict) and "content" in msg: - for block in msg["content"]: - if isinstance(block, dict) and "cache_control" in block: - found_cache_control = True - self.assertEqual(block["cache_control"]["type"], "ephemeral") - - self.assertTrue(found_cache_control, "No cache control found in template messages") - - @patch("fast_agent.llm.provider.anthropic.llm_anthropic.AsyncAnthropic") - async def test_template_caching_off_mode(self, mock_anthropic_class): - """Test that template messages are NOT cached in 'off' mode.""" - llm = self._create_llm(cache_mode="off") - - # Mock the Anthropic client - mock_client = MagicMock() - mock_anthropic_class.return_value = mock_client - - # Create a proper async context manager for the stream - class MockStream: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return None - - def __aiter__(self): - return iter([]) - 
- # Mock the stream method - mock_client.messages.stream = lambda **kwargs: MockStream() - - # Mock the _process_stream method to return a response - mock_usage = MagicMock() - mock_usage.input_tokens = 100 - mock_usage.output_tokens = 50 - mock_usage.cache_creation_input_tokens = None - mock_usage.cache_read_input_tokens = None - mock_usage.trafficType = None - - mock_response = MagicMock( - content=[MagicMock(type="text", text="Response")], - stop_reason="end_turn", - usage=mock_usage, - ) - llm._process_stream = AsyncMock(return_value=mock_response) - - # Create template messages - template_messages = [ - PromptMessageExtended( - role="user", content=[TextContent(type="text", text="Template message")] - ), - PromptMessageExtended( - role="user", content=[TextContent(type="text", text="Current question")] - ), - ] - - # Apply template with is_template=True - await llm._apply_prompt_provider_specific( - template_messages, request_params=None, is_template=True - ) - - # Check that template messages in history do NOT have cache control - history_messages = llm.history.get(include_completion_history=False) - - # Verify that no template message has cache control - for msg in history_messages: - if isinstance(msg, dict) and "content" in msg: - for block in msg["content"]: - if isinstance(block, dict): - self.assertNotIn( - "cache_control", - block, - "Cache control found in template message when cache_mode is 'off'", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py b/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py new file mode 100644 index 000000000..fed369e89 --- /dev/null +++ b/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py @@ -0,0 +1,300 @@ +""" +Unit tests for Anthropic caching functionality. 
+ +These tests directly test the _convert_extended_messages_to_provider method +to verify cache_control markers are applied correctly based on cache_mode settings. +""" + +import pytest +from mcp.types import CallToolResult, TextContent + +from fast_agent.config import AnthropicSettings, Settings +from fast_agent.context import Context +from fast_agent.llm.provider.anthropic.cache_planner import AnthropicCachePlanner +from fast_agent.llm.provider.anthropic.llm_anthropic import AnthropicLLM +from fast_agent.llm.provider.anthropic.multipart_converter_anthropic import AnthropicConverter +from fast_agent.mcp.prompt_message_extended import PromptMessageExtended +from fast_agent.types import RequestParams + + +class TestAnthropicCaching: + """Test cases for Anthropic caching functionality.""" + + def _create_context_with_cache_mode(self, cache_mode: str) -> Context: + """Create a context with specified cache mode.""" + ctx = Context() + ctx.config = Settings() + ctx.config.anthropic = AnthropicSettings( + api_key="test_key", cache_mode=cache_mode + ) + return ctx + + def _create_llm(self, cache_mode: str = "off") -> AnthropicLLM: + """Create an AnthropicLLM instance with specified cache mode.""" + ctx = self._create_context_with_cache_mode(cache_mode) + llm = AnthropicLLM(context=ctx) + return llm + + def _apply_cache_plan( + self, messages: list[PromptMessageExtended], cache_mode: str, system_blocks: int = 0 + ) -> list[dict]: + planner = AnthropicCachePlanner() + plan = planner.plan_indices(messages, cache_mode=cache_mode, system_cache_blocks=system_blocks) + converted = [AnthropicConverter.convert_to_anthropic(m) for m in messages] + for idx in plan: + AnthropicLLM._apply_cache_control_to_message(converted[idx]) + return converted + + def test_conversion_off_mode_no_cache_control(self): + """Test that no cache_control is applied when cache_mode is 'off'.""" + # Create test messages + messages = [ + PromptMessageExtended( + role="user", 
content=[TextContent(type="text", text="Hello")] + ), + PromptMessageExtended( + role="assistant", content=[TextContent(type="text", text="Hi there")] + ), + ] + + converted = self._apply_cache_plan(messages, cache_mode="off") + + # Verify no cache_control in any message + assert len(converted) == 2 + for msg in converted: + assert "content" in msg + for block in msg["content"]: + if isinstance(block, dict): + assert "cache_control" not in block, ( + "cache_control should not be present when cache_mode is 'off'" + ) + + def test_conversion_prompt_mode_templates_cached(self): + """Test that template messages get cache_control in 'prompt' mode.""" + # Create template + conversation messages (agent supplies all, flags templates) + template_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="System context")], is_template=True + ), + PromptMessageExtended( + role="assistant", content=[TextContent(type="text", text="Understood")], is_template=True + ), + ] + conversation_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Question")] + ), + ] + + converted = self._apply_cache_plan(template_msgs + conversation_msgs, cache_mode="prompt") + + # Verify we have 3 messages (2 templates + 1 conversation) + assert len(converted) == 3 + + # Template messages should have cache_control + # The last template message should have cache_control on its last block + found_cache_control = False + for i, msg in enumerate(converted[:2]): # First 2 are templates + if "content" in msg: + for block in msg["content"]: + if isinstance(block, dict) and "cache_control" in block: + found_cache_control = True + assert block["cache_control"]["type"] == "ephemeral" + + assert found_cache_control, "Template messages should have cache_control in 'prompt' mode" + + # Conversation message should NOT have cache_control + conv_msg = converted[2] + for block in conv_msg.get("content", []): + if isinstance(block, dict): + assert 
"cache_control" not in block, ( + "Conversation messages should not have cache_control in 'prompt' mode" + ) + + def test_conversion_auto_mode_templates_cached(self): + """Test that template messages get cache_control in 'auto' mode.""" + template_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Template")], is_template=True + ), + ] + conversation_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Question")] + ), + ] + + converted = self._apply_cache_plan(template_msgs + conversation_msgs, cache_mode="auto") + + # Template message should have cache_control + found_cache_control = False + template_msg = converted[0] + if "content" in template_msg: + for block in template_msg["content"]: + if isinstance(block, dict) and "cache_control" in block: + found_cache_control = True + assert block["cache_control"]["type"] == "ephemeral" + + assert found_cache_control, "Template messages should have cache_control in 'auto' mode" + + def test_conversion_off_mode_templates_not_cached(self): + """Test that template messages do NOT get cache_control when cache_mode is 'off'.""" + template_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Template")], is_template=True + ), + PromptMessageExtended( + role="assistant", content=[TextContent(type="text", text="Response")], is_template=True + ), + ] + conversation_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Question")] + ), + ] + + converted = self._apply_cache_plan(template_msgs + conversation_msgs, cache_mode="off") + + # No messages should have cache_control + for msg in converted: + if "content" in msg: + for block in msg["content"]: + if isinstance(block, dict): + assert "cache_control" not in block, ( + "No messages should have cache_control when cache_mode is 'off'" + ) + + def test_conversion_multiple_messages_structure(self): + """Test that message structure is 
preserved during conversion.""" + messages = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="First")] + ), + PromptMessageExtended( + role="assistant", content=[TextContent(type="text", text="Second")] + ), + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Third")] + ), + ] + + converted = [AnthropicConverter.convert_to_anthropic(m) for m in messages] + + # Verify structure + assert len(converted) == 3 + assert converted[0]["role"] == "user" + assert converted[1]["role"] == "assistant" + assert converted[2]["role"] == "user" + + def test_build_request_messages_avoids_duplicate_tool_results(self): + """Ensure tool_result blocks are only included once per tool use.""" + llm = self._create_llm() + tool_id = "toolu_test" + tool_result = CallToolResult( + content=[TextContent(type="text", text="result payload")], isError=False + ) + user_msg = PromptMessageExtended(role="user", content=[], tool_results={tool_id: tool_result}) + history = [user_msg] + + params = llm.get_request_params(RequestParams(use_history=True)) + message_param = AnthropicConverter.convert_to_anthropic(user_msg) + + prepared = llm._build_request_messages(params, message_param, history=history) + + tool_blocks = [ + block + for msg in prepared + for block in msg.get("content", []) + if isinstance(block, dict) and block.get("type") == "tool_result" + ] + + assert len(tool_blocks) == 1 + assert tool_blocks[0]["tool_use_id"] == tool_id + + def test_build_request_messages_includes_current_when_history_empty(self): + """Fallback to the current message if history produced no entries.""" + llm = self._create_llm() + params = llm.get_request_params(RequestParams(use_history=True)) + msg = PromptMessageExtended(role="user", content=[TextContent(type="text", text="hi")]) + message_param = AnthropicConverter.convert_to_anthropic(msg) + + prepared = llm._build_request_messages(params, message_param, history=[]) + + assert prepared[-1] == 
message_param + + def test_build_request_messages_without_history(self): + """When history is disabled, always send the current message.""" + llm = self._create_llm() + params = llm.get_request_params(RequestParams(use_history=False)) + msg = PromptMessageExtended(role="user", content=[TextContent(type="text", text="hi")]) + message_param = AnthropicConverter.convert_to_anthropic(msg) + + prepared = llm._build_request_messages(params, message_param, history=[]) + + assert prepared == [message_param] + + def test_conversion_empty_messages(self): + """Test conversion of empty message list.""" + llm = self._create_llm(cache_mode="off") + + converted = llm._convert_extended_messages_to_provider([]) + + assert converted == [] + + def test_conversion_with_templates_only(self): + """Test conversion when only templates exist (no conversation).""" + # Create template messages + template_msgs = [ + PromptMessageExtended( + role="user", content=[TextContent(type="text", text="Template")], is_template=True + ), + ] + + converted = self._apply_cache_plan(template_msgs, cache_mode="prompt") + + # Should have just the template + assert len(converted) == 1 + + # Template should have cache_control + found_cache_control = False + for block in converted[0].get("content", []): + if isinstance(block, dict) and "cache_control" in block: + found_cache_control = True + + assert found_cache_control, "Template should have cache_control in 'prompt' mode" + + def test_cache_control_on_last_content_block(self): + """Test that cache_control is applied to the last content block of template messages.""" + # Create a template with multiple content blocks + template_msgs = [ + PromptMessageExtended( + role="user", + content=[ + TextContent(type="text", text="First block"), + TextContent(type="text", text="Second block"), + ], + is_template=True, + ), + ] + + converted = self._apply_cache_plan(template_msgs, cache_mode="prompt") + + # Cache control should be on the last block + content_blocks = 
converted[0]["content"] + assert len(content_blocks) == 2 + + # First block should NOT have cache_control + if isinstance(content_blocks[0], dict): + # Cache control might be on any block, but typically the last one + pass + + # At least one block should have cache_control + found_cache_control = any( + isinstance(block, dict) and "cache_control" in block + for block in content_blocks + ) + assert found_cache_control, "Template should have cache_control" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/fast_agent/llm/providers/test_augmented_llm_azure.py b/tests/unit/fast_agent/llm/providers/test_llm_azure.py similarity index 100% rename from tests/unit/fast_agent/llm/providers/test_augmented_llm_azure.py rename to tests/unit/fast_agent/llm/providers/test_llm_azure.py diff --git a/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py b/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py new file mode 100644 index 000000000..e406dd46f --- /dev/null +++ b/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py @@ -0,0 +1,68 @@ +import pytest +from mcp.types import CallToolRequest, CallToolRequestParams, CallToolResult, TextContent + +from fast_agent.context import Context +from fast_agent.core.prompt import Prompt +from fast_agent.llm.provider.openai.llm_openai import OpenAILLM +from fast_agent.llm.request_params import RequestParams +from fast_agent.types import PromptMessageExtended + + +class CapturingOpenAI(OpenAILLM): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.captured = None + + async def _openai_completion(self, message, request_params=None, tools=None): + self.captured = message + return Prompt.assistant("ok") + + +def _build_tool_messages(): + tool_call = CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="demo_tool", arguments={"arg": "value"}), + ) + assistant_tool_call = Prompt.assistant("calling tool", tool_calls={"call_1": tool_call}) + + 
tool_result_msg = PromptMessageExtended( + role="user", + content=[TextContent(type="text", text="tool response payload")], + tool_results={ + "call_1": CallToolResult( + content=[TextContent(type="text", text="result details")], + ) + }, + ) + return assistant_tool_call, tool_result_msg + + +@pytest.mark.asyncio +async def test_apply_prompt_avoids_duplicate_last_message_when_using_history(): + context = Context() + llm = CapturingOpenAI(context=context) + + assistant_tool_call, tool_result_msg = _build_tool_messages() + history = [assistant_tool_call, tool_result_msg] + + await llm._apply_prompt_provider_specific(history, None, None) + + assert isinstance(llm.captured, list) + assert llm.captured[0]["role"] == "assistant" + # Tool result conversion should follow the assistant tool_calls + assert any(msg.get("role") == "tool" for msg in llm.captured) + + +@pytest.mark.asyncio +async def test_apply_prompt_converts_last_message_when_history_disabled(): + context = Context() + llm = CapturingOpenAI(context=context) + + _, tool_result_msg = _build_tool_messages() + + await llm._apply_prompt_provider_specific( + [tool_result_msg], RequestParams(use_history=False), None + ) + + assert isinstance(llm.captured, list) + assert llm.captured # should send something to completion when history is off diff --git a/tests/unit/fast_agent/llm/providers/test_augmented_llm_tensorzero_unit.py b/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py similarity index 100% rename from tests/unit/fast_agent/llm/providers/test_augmented_llm_tensorzero_unit.py rename to tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py diff --git a/tests/unit/fast_agent/llm/test_clear_behavior.py b/tests/unit/fast_agent/llm/test_clear_behavior.py index 7ad76aaf1..cf970100b 100644 --- a/tests/unit/fast_agent/llm/test_clear_behavior.py +++ b/tests/unit/fast_agent/llm/test_clear_behavior.py @@ -25,21 +25,21 @@ def _make_user_message(text: str) -> PromptMessageExtended: 
@pytest.mark.asyncio async def test_llm_clear_retains_templates(): ctx = Context() + agent = LlmAgent(config=AgentConfig(name="agent-under-test"), context=ctx) llm = PassthroughLLM(provider=Provider.FAST_AGENT, context=ctx) + agent._llm = llm - await llm.apply_prompt_template(_make_template_prompt("template context"), "demo") - assert [msg.first_text() for msg in llm.message_history] == ["template context"] + await agent.apply_prompt_template(_make_template_prompt("template context"), "demo") + assert [msg.first_text() for msg in agent.message_history] == ["template context"] - await llm.generate([_make_user_message("hello")]) - assert len(llm.message_history) >= 3 # template + user + assistant + await agent.generate(_make_user_message("hello")) + assert len(agent.message_history) >= 3 # template + user + assistant - llm.clear() - assert [msg.first_text() for msg in llm.message_history] == ["template context"] - assert len(llm.history.get()) == 1 + agent.clear() + assert [msg.first_text() for msg in agent.message_history] == ["template context"] - llm.clear(clear_prompts=True) - assert llm.message_history == [] - assert llm.history.get() == [] + agent.clear(clear_prompts=True) + assert agent.message_history == [] @pytest.mark.asyncio diff --git a/tests/unit/fast_agent/llm/test_passthrough.py b/tests/unit/fast_agent/llm/test_passthrough.py index 36e2b037c..8757442a2 100644 --- a/tests/unit/fast_agent/llm/test_passthrough.py +++ b/tests/unit/fast_agent/llm/test_passthrough.py @@ -89,6 +89,15 @@ async def test_generates_structured(): ) +@pytest.mark.asyncio +async def test_returns_assistant_message_verbatim(): + llm: FastAgentLLMProtocol = PassthroughLLM() + assistant_msg = Prompt.assistant("already answered") + result = await llm.generate([assistant_msg]) + assert result.role == "assistant" + assert result.first_text() == "already answered" + + @pytest.mark.asyncio async def test_usage_tracking(): """Test that PassthroughLLM correctly tracks usage""" diff --git 
a/tests/unit/fast_agent/llm/test_prepare_arguments.py b/tests/unit/fast_agent/llm/test_prepare_arguments.py index 2637250ea..287bacc6c 100644 --- a/tests/unit/fast_agent/llm/test_prepare_arguments.py +++ b/tests/unit/fast_agent/llm/test_prepare_arguments.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, List from fast_agent.llm.fastagent_llm import FastAgentLLM from fast_agent.llm.provider.anthropic.llm_anthropic import AnthropicLLM @@ -24,6 +24,12 @@ async def _apply_prompt_provider_specific( """Implement the abstract method with minimal functionality""" return multipart_messages[-1] + def _convert_extended_messages_to_provider( + self, messages: List[PromptMessageExtended] + ) -> List[Any]: + """Convert messages to provider format - stub returns empty list""" + return [] + class TestRequestParamsInLLM: """Test suite for RequestParams handling in LLM classes"""