diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index 9395fe1c69..8940365930 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -10,10 +10,9 @@ import tempfile import threading from collections.abc import AsyncIterable, AsyncIterator, Generator, Sequence -from contextlib import suppress +from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress from dataclasses import asdict, is_dataclass from pathlib import Path -from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress from typing import Protocol, cast from agent_framework import ( @@ -73,6 +72,7 @@ MessageContentOutputTextContent, MessageContentReasoningTextContent, MessageContentRefusalContent, + MessageRole, OAuthConsentRequestOutputItem, OutputItem, OutputItemApplyPatchToolCall, @@ -117,6 +117,8 @@ logger = logging.getLogger(__name__) +_AZURE_RESPONSES_MESSAGE_ROLE_TYPE = f"{MessageRole.__module__}:{MessageRole.__qualname__}" + # region Approval Storage class ApprovalStorage(Protocol): @@ -250,7 +252,12 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin storage_path = (root_path / context_id).resolve() if not storage_path.is_relative_to(root_path): raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}") - return FileCheckpointStorage(storage_path) + return FileCheckpointStorage( + storage_path, + # Keep this provider-specific allowlist narrow. Hosted workflow + # checkpoints can persist Azure's role enum inside Message objects. + allowed_checkpoint_types=[_AZURE_RESPONSES_MESSAGE_ROLE_TYPE], + ) # endregion Approval Storage diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index d5e25b99f9..f989fff7de 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -26,6 +26,9 @@ Message, RawAgent, ResponseStream, + WorkflowCheckpoint, + WorkflowCheckpointException, + WorkflowMessage, ) from azure.ai.agentserver.responses import InMemoryResponseProvider from mcp import McpError @@ -2712,6 +2715,23 @@ def _helper() -> Callable[[str, str], FileCheckpointStorage]: return _checkpoint_storage_for_context + @staticmethod + def _checkpoint_with_azure_message_role() -> WorkflowCheckpoint: + from azure.ai.agentserver.responses.models import MessageRole + + return WorkflowCheckpoint( + workflow_name="wf", + graph_signature_hash="hash", + messages={ + "executor": [ + WorkflowMessage( + data=Message(role=MessageRole.USER, contents=[Content.from_text("hello")]), + source_id="source", + ) + ] + }, + ) + def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None: helper = self._helper() root = tmp_path / "root" @@ -2720,6 +2740,65 @@ def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None: assert storage.storage_path.is_dir() assert storage.storage_path.parent == root.resolve() + async def test_storage_allows_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None: + from azure.ai.agentserver.responses.models import MessageRole + + helper = self._helper() + root = tmp_path / "root" + root.mkdir() + storage = helper(str(root), "resp_abc123") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + loaded = await storage.load(checkpoint.checkpoint_id) + + loaded_message = loaded.messages["executor"][0].data + assert isinstance(loaded_message, Message) + assert type(loaded_message.role) is MessageRole + assert loaded_message.role == MessageRole.USER + assert loaded_message.text == "hello" + + async def test_plain_storage_blocks_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None: + storage = FileCheckpointStorage(tmp_path / "plain") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + with pytest.raises(WorkflowCheckpointException, match="MessageRole"): + await storage.load(checkpoint.checkpoint_id) + + async def test_get_latest_restores_azure_message_role(self, tmp_path: Any) -> None: + from azure.ai.agentserver.responses.models import MessageRole + + helper = self._helper() + root = tmp_path / "root" + root.mkdir() + storage = helper(str(root), "resp_abc123") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + latest = await storage.get_latest(workflow_name="wf") + + assert latest is not None + assert latest.checkpoint_id == checkpoint.checkpoint_id + latest_message = latest.messages["executor"][0].data + assert isinstance(latest_message, Message) + assert type(latest_message.role) is MessageRole + + async def test_get_latest_silently_skips_without_allowlist( + self, tmp_path: Any, caplog: pytest.LogCaptureFixture + ) -> None: + import logging + + storage = FileCheckpointStorage(tmp_path / "plain") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + with caplog.at_level(logging.WARNING, logger="agent_framework"): + latest = await storage.get_latest(workflow_name="wf") + + assert latest is None + assert any("MessageRole" in message for message in caplog.messages) + @pytest.mark.parametrize( "bad_id", [ @@ -2923,6 +3002,8 @@ async def test_malicious_context_id_rejected_e2e(self, tmp_path: Any, context_fi f"before={before} after={after}" ) assert list(root.iterdir()) == [], f"Checkpoint directory created inside root for {context_field}={bad_id!r}" + + # region Agent lifecycle (lazy entry & OAuth consent surfacing)