27 changes: 0 additions & 27 deletions agent.py

This file was deleted.

134 changes: 125 additions & 9 deletions src/fast_agent/agents/llm_decorator.py
@@ -42,7 +42,11 @@
from pydantic import BaseModel

from fast_agent.agents.agent_types import AgentConfig, AgentType
from fast_agent.constants import (
CONTROL_MESSAGE_SAVE_HISTORY,
FAST_AGENT_ERROR_CHANNEL,
FAST_AGENT_REMOVED_METADATA_CHANNEL,
)
from fast_agent.context import Context
from fast_agent.core.logging.logger import get_logger
from fast_agent.interfaces import (
@@ -127,6 +131,17 @@ class RemovedContentSummary:
message: str


@dataclass
class _CallContext:
"""Internal helper for assembling an LLM call."""

full_history: List[PromptMessageExtended]
call_params: RequestParams | None
persist_history: bool
sanitized_messages: List[PromptMessageExtended]
summary: RemovedContentSummary | None


class LlmDecorator(StreamingAgentMixin, AgentProtocol):
"""
A pure delegation wrapper around LlmAgent instances.
@@ -150,6 +165,9 @@ def __init__(
self._tracer = trace.get_tracer(__name__)
self.instruction = self.config.instruction

# Agent-owned conversation state (PromptMessageExtended only)
self._message_history: List[PromptMessageExtended] = []

# Store the default request params from config
self._default_request_params = self.config.default_request_params

@@ -338,7 +356,16 @@ async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name
Returns:
String representation of the assistant's response if generated
"""
from fast_agent.types import PromptMessageExtended

assert self._llm

multipart_messages = PromptMessageExtended.parse_get_prompt_result(prompt_result)
for msg in multipart_messages:
msg.is_template = True

self._message_history = [msg.model_copy(deep=True) for msg in multipart_messages]

return await self._llm.apply_prompt_template(prompt_result, prompt_name)
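
The template-seeding logic above parses the prompt into messages, flags each one with `is_template`, and deep-copies them into the agent-owned history. A minimal sketch of that pattern, with a hypothetical `Msg` stand-in for `PromptMessageExtended`:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Msg:  # hypothetical stand-in for PromptMessageExtended
    role: str
    text: str
    is_template: bool = False

def seed_history_from_template(template: list[Msg]) -> list[Msg]:
    # Flag each template message so the template prefix can be found later;
    # frozen instances mean no deep copy is needed for this sketch.
    return [replace(m, is_template=True) for m in template]

history = seed_history_from_template([Msg("user", "You are a helpful agent.")])
assert all(m.is_template for m in history)
```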

async def apply_prompt(
@@ -375,6 +402,11 @@ def clear(self, *, clear_prompts: bool = False) -> None:
if not self._llm:
return
self._llm.clear(clear_prompts=clear_prompts)
if clear_prompts:
self._message_history = []
else:
template_prefix = self._template_prefix_messages()
self._message_history = [msg.model_copy(deep=True) for msg in template_prefix]
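
`clear()` now distinguishes a full reset from a conversational reset: unless `clear_prompts=True`, the contiguous template prefix survives. A sketch under the same stand-in assumption:

```python
from dataclasses import dataclass

@dataclass
class Msg:  # hypothetical stand-in for PromptMessageExtended
    text: str
    is_template: bool = False

def clear_history(history: list[Msg], *, clear_prompts: bool = False) -> list[Msg]:
    if clear_prompts:
        return []
    # Keep only the contiguous leading run of template messages.
    kept: list[Msg] = []
    for m in history:
        if not m.is_template:
            break
        kept.append(m)
    return kept

h = [Msg("sys", True), Msg("hello"), Msg("hi")]
assert [m.text for m in clear_history(h)] == ["sys"]
assert clear_history(h, clear_prompts=True) == []
```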

async def structured(
self,
Expand Down Expand Up @@ -445,9 +477,16 @@ async def _generate_with_summary(
tools: List[Tool] | None = None,
) -> Tuple[PromptMessageExtended, RemovedContentSummary | None]:
assert self._llm, "LLM is not attached"
call_ctx = self._prepare_llm_call(messages, request_params)

response = await self._llm.generate(
call_ctx.full_history, call_ctx.call_params, tools
)

if call_ctx.persist_history:
self._persist_history(call_ctx.sanitized_messages, response)

return response, call_ctx.summary

async def _structured_with_summary(
self,
@@ -456,9 +495,68 @@ async def _structured_with_summary(
request_params: RequestParams | None = None,
) -> Tuple[Tuple[ModelT | None, PromptMessageExtended], RemovedContentSummary | None]:
assert self._llm, "LLM is not attached"
call_ctx = self._prepare_llm_call(messages, request_params)

structured_result = await self._llm.structured(
call_ctx.full_history, model, call_ctx.call_params
)

if call_ctx.persist_history:
try:
_, assistant_message = structured_result
self._persist_history(call_ctx.sanitized_messages, assistant_message)
except Exception:
pass
return structured_result, call_ctx.summary

def _prepare_llm_call(
self, messages: List[PromptMessageExtended], request_params: RequestParams | None = None
) -> _CallContext:
"""Normalize template/history handling for both generate and structured."""
sanitized_messages, summary = self._sanitize_messages_for_llm(messages)
final_request_params = self._llm.get_request_params(request_params)

use_history = final_request_params.use_history if final_request_params else True
call_params = final_request_params.model_copy() if final_request_params else None
if call_params and not call_params.use_history:
call_params.use_history = True

base_history = self._message_history if use_history else self._template_prefix_messages()
full_history = [msg.model_copy(deep=True) for msg in base_history]
full_history.extend(sanitized_messages)

return _CallContext(
full_history=full_history,
call_params=call_params,
persist_history=use_history,
sanitized_messages=sanitized_messages,
summary=summary,
)
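
`_prepare_llm_call` centralizes the history policy for both `generate` and `structured`: with `use_history` on, the whole agent-owned transcript is replayed; with it off, only the template prefix plus the new turn. (The copied params also force `use_history` on, apparently so the wrapped LLM treats the supplied transcript as authoritative rather than keeping its own.) A condensed sketch of the base-selection rule, using bare strings as stand-in messages:

```python
def assemble_call_history(
    stored: list[str],
    template_prefix: list[str],
    new_messages: list[str],
    use_history: bool,
) -> list[str]:
    # use_history on: replay the whole agent-owned transcript.
    # use_history off: replay only the template prefix, then the new turn.
    base = stored if use_history else template_prefix
    return [*base, *new_messages]

assert assemble_call_history(["t", "u1", "a1"], ["t"], ["u2"], True) == ["t", "u1", "a1", "u2"]
assert assemble_call_history(["t", "u1", "a1"], ["t"], ["u2"], False) == ["t", "u2"]
```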

def _persist_history(
self,
sanitized_messages: List[PromptMessageExtended],
assistant_message: PromptMessageExtended,
) -> None:
"""Persist the last turn unless explicitly disabled by control text."""
if not sanitized_messages:
return
if sanitized_messages[-1].first_text().startswith(CONTROL_MESSAGE_SAVE_HISTORY):
return

history_messages = [self._strip_removed_metadata(msg) for msg in sanitized_messages]
self._message_history.extend(history_messages)
self._message_history.append(assistant_message)
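
Persistence appends the sanitized user turn and the assistant reply, but skips the turn entirely when the last user message starts with the control prefix. A sketch of that guard:

```python
CONTROL_MESSAGE_SAVE_HISTORY = "***SAVE_HISTORY"

def persist_turn(history: list[str], user_turn: list[str], assistant: str) -> None:
    if not user_turn:
        return
    # A control message is an instruction to the agent, not conversation,
    # so the whole turn is left out of the transcript.
    if user_turn[-1].startswith(CONTROL_MESSAGE_SAVE_HISTORY):
        return
    history.extend(user_turn)
    history.append(assistant)

transcript: list[str] = []
persist_turn(transcript, ["hello"], "hi there")
persist_turn(transcript, ["***SAVE_HISTORY"], "(saved)")
assert transcript == ["hello", "hi there"]
```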

@staticmethod
def _strip_removed_metadata(message: PromptMessageExtended) -> PromptMessageExtended:
"""Remove per-turn removed-content metadata before persisting to history."""
msg_copy = message.model_copy(deep=True)
if msg_copy.channels and FAST_AGENT_REMOVED_METADATA_CHANNEL in msg_copy.channels:
channels = dict(msg_copy.channels)
channels.pop(FAST_AGENT_REMOVED_METADATA_CHANNEL, None)
msg_copy.channels = channels if channels else None
return msg_copy
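
Before a message is persisted, its removed-content channel is stripped and an emptied channels mapping collapses to `None`. A sketch with a plain dict standing in for the channels mapping (the channel name below is illustrative, not the real constant value):

```python
REMOVED_CHANNEL = "fast-agent-removed-metadata"  # illustrative name

def strip_removed_metadata(channels: dict[str, list[str]] | None) -> dict[str, list[str]] | None:
    if not channels or REMOVED_CHANNEL not in channels:
        return channels
    cleaned = {k: v for k, v in channels.items() if k != REMOVED_CHANNEL}
    return cleaned or None  # collapse an empty mapping to None, as the diff does

assert strip_removed_metadata({REMOVED_CHANNEL: ["meta"]}) is None
assert strip_removed_metadata({"other": ["x"], REMOVED_CHANNEL: []}) == {"other": ["x"]}
```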

def _sanitize_messages_for_llm(
self, messages: List[PromptMessageExtended]
@@ -761,9 +859,27 @@ def message_history(self) -> List[PromptMessageExtended]:
Returns:
List of PromptMessageExtended objects representing the conversation history
"""
return self._message_history

@property
def template_messages(self) -> List[PromptMessageExtended]:
"""
Return the template prefix of the message history.

Templates are identified via the is_template flag and are expected to
appear as a contiguous prefix of the history.
"""
return [msg.model_copy(deep=True) for msg in self._template_prefix_messages()]

def _template_prefix_messages(self) -> List[PromptMessageExtended]:
"""Return the leading messages marked as templates (non-copy)."""
prefix: List[PromptMessageExtended] = []
for msg in self._message_history:
if msg.is_template:
prefix.append(msg)
else:
break
return prefix
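
The prefix scan stops at the first non-template message, matching the contiguous-prefix invariant in the docstring. An equivalent sketch using `itertools.takewhile`:

```python
from itertools import takewhile

def template_prefix(history: list[tuple[str, bool]]) -> list[tuple[str, bool]]:
    # Items are (text, is_template); takewhile stops at the first non-template,
    # so a template appearing after normal messages is *not* part of the prefix.
    return list(takewhile(lambda m: m[1], history))

assert template_prefix([("sys", True), ("hi", False), ("stray", True)]) == [("sys", True)]
```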

def pop_last_message(self) -> PromptMessageExtended | None:
"""Remove and return the most recent message from the conversation history."""
5 changes: 2 additions & 3 deletions src/fast_agent/agents/mcp_agent.py
@@ -1320,9 +1320,8 @@ def message_history(self) -> List[PromptMessageExtended]:
Returns:
List of PromptMessageExtended objects representing the conversation history
"""
# Conversation history is maintained at the agent layer; LLM history is diagnostic only.
return super().message_history
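
With history owned by `LlmDecorator`, the subclass property reduces to a `super()` call. For reference, a generic sketch of that delegation pattern (not fast_agent code):

```python
class HistoryOwner:
    def __init__(self) -> None:
        self._message_history: list[str] = []

    @property
    def message_history(self) -> list[str]:
        return self._message_history

class SpecializedAgent(HistoryOwner):
    @property
    def message_history(self) -> list[str]:
        # super() works inside property getters too; the base class list
        # is the single source of truth.
        return super().message_history
```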

@property
def usage_accumulator(self) -> Optional["UsageAccumulator"]:
2 changes: 2 additions & 0 deletions src/fast_agent/constants.py
@@ -31,3 +31,5 @@
{{env}}

The current date is {{currentDate}}."""

CONTROL_MESSAGE_SAVE_HISTORY = "***SAVE_HISTORY"
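
With the prefix defined, a caller can ask an agent to snapshot its history without polluting the transcript. A hypothetical usage sketch — the `send()` coroutine and the filename argument are assumptions, not part of this diff:

```python
CONTROL_MESSAGE_SAVE_HISTORY = "***SAVE_HISTORY"

async def save_history(agent, path: str = "history.json") -> None:
    # Hypothetical: sends the control message; _persist_history sees the
    # prefix and keeps this exchange out of the transcript.
    await agent.send(f"{CONTROL_MESSAGE_SAVE_HISTORY} {path}")
```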