Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/packages/azurefunctions/tests/test_func_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def test_roundtrip_agent_executor_response(self) -> None:
original = AgentExecutorResponse(
executor_id="test_exec",
agent_response=AgentResponse(messages=[Message(role="assistant", text="Reply")]),
full_conversation=[Message(role="assistant", text="Reply")],
)
encoded = serialize_value(original)
decoded = deserialize_value(encoded)
Expand Down
5 changes: 5 additions & 0 deletions python/packages/azurefunctions/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def test_extract_from_agent_executor_response_with_text(self) -> None:
response = AgentExecutorResponse(
executor_id="exec",
agent_response=AgentResponse(messages=[Message(role="assistant", text="Response text")]),
full_conversation=[Message(role="assistant", text="Response text")],
)

result = _extract_message_content(response)
Expand All @@ -228,6 +229,10 @@ def test_extract_from_agent_executor_response_with_messages(self) -> None:
Message(role="assistant", text="Last message"),
]
),
full_conversation=[
Message(role="user", text="First"),
Message(role="assistant", text="Last message"),
],
)

result = _extract_message_content(response)
Expand Down
85 changes: 52 additions & 33 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from typing import Any, cast
from typing import Any, Literal, cast

from typing_extensions import Never

Expand Down Expand Up @@ -57,7 +57,7 @@ class AgentExecutorResponse:

executor_id: str
agent_response: AgentResponse
full_conversation: list[Message] | None = None
full_conversation: list[Message]


class AgentExecutor(Executor):
Expand All @@ -83,13 +83,25 @@ def __init__(
*,
session: AgentSession | None = None,
id: str | None = None,
context_mode: Literal["full", "last_agent", "custom"] | None = None,
context_filter: Callable[[list[Message]], list[Message]] | None = None,
):
"""Initialize the executor with a unique identifier.

Args:
agent: The agent to be wrapped by this executor.
session: The session to use for running the agent. If None, a new session will be created.
id: A unique identifier for the executor. If None, the agent's name will be used if available.
context_mode: Configuration for how the executor should manage conversation context upon
receiving an AgentExecutorResponse as input. Options:
- "full": append the full conversation (all prior messages + latest agent response) to the
cache for the agent run. This is the default mode.
- "last_agent": provide only the messages from the latest agent response as context for
the agent run.
- "custom": use the provided context_filter function to determine which messages to include
as context for the agent run.
context_filter: An optional function for filtering conversation context when context_mode is set
to "custom".
"""
# Prefer provided id; else use agent.name if present; else generate deterministic prefix
exec_id = id or resolve_agent_id(agent)
Expand All @@ -107,6 +119,14 @@ def __init__(
# This tracks the full conversation after each run
self._full_conversation: list[Message] = []

# Context mode validation
self._context_mode = context_mode or "full"
self._context_filter = context_filter
if self._context_mode not in {"full", "last_agent", "custom"}:
raise ValueError("context_mode must be one of 'full', 'last_agent', or 'custom'.")
if self._context_mode == "custom" and not self._context_filter:
raise ValueError("context_filter must be provided when context_mode is set to 'custom'.")

@property
def agent(self) -> SupportsAgentRun:
"""Get the underlying agent wrapped by this executor."""
Expand All @@ -129,6 +149,7 @@ async def run(
run the agent and emit an AgentExecutorResponse downstream.
"""
self._cache.extend(request.messages)

if request.should_respond:
await self._run_agent_and_emit(ctx)

Expand All @@ -143,19 +164,27 @@ async def from_response(
Strategy: treat the prior response's messages as the conversation state and
immediately run the agent to produce a new response.
"""
# Replace cache with full conversation if available, else fall back to agent_response messages.
source_messages = (
prior.full_conversation if prior.full_conversation is not None else prior.agent_response.messages
)
self._cache = list(source_messages)
if self._context_mode == "full":
self._cache.extend(prior.full_conversation)
elif self._context_mode == "last_agent":
self._cache.extend(prior.agent_response.messages)
else:
if not self._context_filter:
# This should never happen due to validation in __init__, but mypy doesn't track that well
raise ValueError("context_filter function must be provided for 'custom' context_mode.")
self._cache.extend(self._context_filter(prior.full_conversation))

await self._run_agent_and_emit(ctx)

@handler
async def from_str(
self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate]
) -> None:
"""Accept a raw user prompt string and run the agent (one-shot)."""
self._cache = normalize_messages_input(text)
"""Accept a raw user prompt string and run the agent.

The new string input will be added to the cache which is used as the conversation context for the agent run.
"""
self._cache.extend(normalize_messages_input(text))
await self._run_agent_and_emit(ctx)

@handler
Expand All @@ -164,8 +193,11 @@ async def from_message(
message: Message,
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
) -> None:
"""Accept a single Message as input."""
self._cache = normalize_messages_input(message)
"""Accept a single Message as input.

The new message will be added to the cache which is used as the conversation context for the agent run.
"""
self._cache.extend(normalize_messages_input(message))
await self._run_agent_and_emit(ctx)

@handler
Expand All @@ -174,8 +206,11 @@ async def from_messages(
messages: list[str | Message],
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
) -> None:
"""Accept a list of chat inputs (strings or Message) as conversation context."""
self._cache = normalize_messages_input(messages)
"""Accept a list of chat inputs (strings or Message) as conversation context.

The new messages will be added to the cache which is used as the conversation context for the agent run.
"""
self._cache.extend(normalize_messages_input(messages))
await self._run_agent_and_emit(ctx)

@response_handler
Expand Down Expand Up @@ -249,24 +284,10 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
state: Checkpoint data dict
"""
cache_payload = state.get("cache")
if cache_payload:
try:
self._cache = cache_payload
except Exception as exc:
logger.warning("Failed to restore cache: %s", exc)
self._cache = []
else:
self._cache = []
self._cache = cache_payload or []

full_conversation_payload = state.get("full_conversation")
if full_conversation_payload:
try:
self._full_conversation = full_conversation_payload
except Exception as exc:
logger.warning("Failed to restore full conversation: %s", exc)
self._full_conversation = []
else:
self._full_conversation = []
self._full_conversation = full_conversation_payload or []

session_payload = state.get("agent_session")
if session_payload:
Expand All @@ -279,12 +300,10 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
self._session = self._agent.create_session()

pending_requests_payload = state.get("pending_agent_requests")
if pending_requests_payload:
self._pending_agent_requests = pending_requests_payload
self._pending_agent_requests = pending_requests_payload or {}

pending_responses_payload = state.get("pending_responses_to_agent")
if pending_responses_payload:
self._pending_responses_to_agent = pending_responses_payload
self._pending_responses_to_agent = pending_responses_payload or []

def reset(self) -> None:
"""Reset the internal cache of the executor."""
Expand Down
Loading
Loading