diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 585898ae52..bfcd89946c 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -92,12 +92,16 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: """Merge two options dicts, with override values taking precedence. + ``None`` is treated as "unset": ``None`` overrides are skipped so they don't clobber a base + value, and the merged result is stripped of any remaining ``None`` values in a final pass so + unset options are never forwarded (e.g. an unset ``store`` is left for the service to default). + Args: base: The base options dict. override: The override options dict (values take precedence). Returns: - A new merged options dict. + A new merged options dict containing no ``None`` values. """ result = dict(base) @@ -123,7 +127,7 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, result["instructions"] = f"{result['instructions']}\n{value}" else: result[key] = value - return result + return {key: value for key, value in result.items() if value is not None} def _sanitize_agent_name(agent_name: str | None) -> str | None: @@ -460,6 +464,9 @@ async def _run_after_providers( if provider_session is None and self.context_providers: provider_session = AgentSession() + # When per-service-call persistence is enabled, the per-service-call middleware owns + # HistoryProvider persistence (in both the local and service-managed cases), so skip + # them on the once-per-run path to avoid double persistence. per_service_call_history_required = self.require_per_service_call_history_persistence and any( isinstance(provider, HistoryProvider) for provider in self.context_providers ) @@ -686,11 +693,14 @@ def __init__( description: A brief description of the agent's purpose. context_providers: Context providers to include during agent invocation. 
middleware: List of middleware to intercept agent and function invocations. - require_per_service_call_history_persistence: When True, history providers are invoked - around each model call instead of once per ``run()`` when the service - is not already storing history. If service-side storage is active for - the run, the agent skips local history providers and relies on the - service-managed conversation instead. + require_per_service_call_history_persistence: When True (and a HistoryProvider is + present), the provider always persists history via per-service-call middleware, + regardless of whether the client stores history server-side. If the client does + not store history, the middleware also loads providers around each model call and + drives the function loop with a local conversation; if it does, loading is skipped + (the service-managed conversation is the source of truth) and the middleware only + persists. A warning is logged for providers with ``load_messages=True`` when + loading is skipped because service-side storage is active. default_options: A TypedDict containing chat options. 
When using a typed agent like ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for provider-specific options including temperature, max_tokens, model, @@ -791,22 +801,20 @@ def _resolve_per_service_call_history_providers( self, *, session: AgentSession | None, - options: Mapping[str, Any] | None, + conversation_id: str | None, service_stores_history: bool, ) -> list[HistoryProvider]: history_providers = self._get_history_providers() if not self.require_per_service_call_history_persistence or not history_providers: return [] - conversation_id = ( - session.service_session_id - if session and session.service_session_id - else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id")) - ) - if service_stores_history: - return [] - - if conversation_id is not None: + # A live service-managed session id takes precedence over the resolved conversation id. + if session and session.service_session_id: + conversation_id = session.service_session_id + # Without service-side storage the middleware persists locally and drives the function + # loop with a local sentinel, which cannot be reconciled with an existing service-managed + # conversation. When the service stores history, an existing conversation id is expected. + if conversation_id is not None and not service_stores_history: raise AgentInvalidRequestException( "require_per_service_call_history_persistence cannot be used " "with an existing service-managed conversation." @@ -1167,18 +1175,34 @@ async def _prepare_run_context( input_messages = normalize_messages(messages) - # `store` in runtime or agent options takes precedence over client-level storage - # indicators. An explicit `store=False` forces local (in-memory) history injection, - # even if the client is configured to use service-side storage by default. 
- store_ = opts.get("store", self.default_options.get("store", getattr(self.client, "STORES_BY_DEFAULT", False))) + # Combine agent-level defaults with runtime options up front so the decisions below read + # `store` from a single place rather than introspecting both dicts. _merge_options applies + # the same precedence used for the actual client call (runtime wins; unset/None falls back + # to the agent default). + effective_options = _merge_options(self.default_options, opts) + + # `store` in runtime or agent options takes precedence over the client's default + # storage behavior. An explicit `store=False` forces local (in-memory) history + # injection even when the client stores server-side by default; an explicit + # `store=True` forces service-side storage. A `store=None`/unset value means the + # service falls back to its own default. + explicit_store = effective_options.get("store") + # Internal behavior hint: will the service own history for this run? Only when the + # user left `store` unset do we fall back to the client's STORES_BY_DEFAULT. + service_stores_history = ( + explicit_store if explicit_store is not None else getattr(self.client, "STORES_BY_DEFAULT", False) + ) + # Resolve conversation_id from the same combined view so an agent-level default is honored + # when the runtime omits it (a live session id still takes precedence below). 
+ effective_conversation_id = effective_options.get("conversation_id") # Auto-inject InMemoryHistoryProvider when session is provided, no context providers # registered, and no service-side storage indicators if ( session is not None and not self.context_providers and not session.service_session_id - and not opts.get("conversation_id") - and not store_ + and not effective_conversation_id + and not service_stores_history ): self.context_providers.append(InMemoryHistoryProvider()) @@ -1188,10 +1212,30 @@ async def _prepare_run_context( per_service_call_history_providers = self._resolve_per_service_call_history_providers( session=active_session, - options=opts, - service_stores_history=bool(store_), + conversation_id=effective_conversation_id, + service_stores_history=service_stores_history, ) + # When require_per_service_call_history_persistence is set together with a + # HistoryProvider, the per-service-call middleware (installed below) always persists + # the provider. ``service_stores_history`` only selects how the middleware behaves: + # - service does not store: the middleware also loads providers and drives the function + # loop with a local sentinel conversation id, or + # - service stores: the middleware skips loading (the service owns history) and simply + # persists each service call while the real conversation id flows through. + # In the service-managed case loading is skipped, so warn for providers that expect to load. + history_providers = self._get_history_providers() + if self.require_per_service_call_history_persistence and history_providers and service_stores_history: + for provider in history_providers: + if provider.load_messages: + logger.warning( + "HistoryProvider '%s' has load_messages=True but the chat client stores history " + "server-side; skipping local history load and relying on the service-managed " + "conversation. 
Set store=False to load from the provider, or load_messages=False " + "to silence this warning.", + provider.source_id, + ) + session_context, chat_options = await self._prepare_session_and_messages( session=active_session, input_messages=input_messages, @@ -1265,8 +1309,8 @@ async def _prepare_run_context( } if model is not None: run_opts["model"] = model - # Remove None values and merge with chat_options - run_opts = {k: v for k, v in run_opts.items() if v is not None} + # _merge_options strips unset (None) options, so e.g. an unset `store` is not forwarded + # and the service decides its own default. co = _merge_options(chat_options, run_opts) # Build session_messages from session context: context messages + input messages @@ -1280,6 +1324,7 @@ async def _prepare_run_context( agent=self, session=active_session, providers=per_service_call_history_providers, + service_stores_history=service_stores_history, ) existing_middleware = effective_client_kwargs.get("middleware") if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)): @@ -1319,7 +1364,7 @@ async def _prepare_run_context( "input_messages": input_messages, "session_messages": session_messages, "agent_name": agent_name, - "suppress_response_id": bool(per_service_call_history_providers), + "suppress_response_id": bool(per_service_call_history_providers) and not service_stores_history, "chat_options": co, "compaction_strategy": compaction_strategy or self.compaction_strategy, "tokenizer": tokenizer or self.tokenizer, @@ -1413,11 +1458,15 @@ async def _prepare_session_and_messages( options=options or {}, ) + # When per-service-call persistence is enabled, the per-service-call middleware owns + # HistoryProvider loading (it loads locally when the service does not store history, or + # relies on the service when it does), so skip them on the once-per-run before_run path. 
per_service_call_history_required = self.require_per_service_call_history_persistence and bool( self._get_history_providers() ) - # Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history) + # Run before_run providers (forward order, skip HistoryProvider when per-service-call + # persistence owns loading) for provider in self.context_providers: if per_service_call_history_required and isinstance(provider, HistoryProvider): continue diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index ddd765e654..2048dee71d 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -597,10 +597,12 @@ def as_agent( and dict literals are accepted without specialized option typing. context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. - require_per_service_call_history_persistence: Whether to require per-service-call - chat history persistence. When enabled, history providers are invoked around - each model call instead of once per ``run()`` when the service is not already - storing history. + require_per_service_call_history_persistence: When enabled (and a HistoryProvider is + present), the provider always persists history after each model call. If the + client does not store history server-side, history providers are also loaded and + injected around each model call; if it does, provider loading is skipped and the + service-managed conversation is the source of truth (persistence still happens + after each model call). function_invocation_configuration: Optional function invocation configuration override. compaction_strategy: Optional agent-level compaction override. When omitted, client-level compaction defaults remain in effect for each call. 
diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index be4d4ea285..181b6416e9 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -580,6 +580,7 @@ def __init__( agent: SupportsAgentRun, session: AgentSession, providers: Sequence[HistoryProvider], + service_stores_history: bool = False, ) -> None: """Initialize the middleware. @@ -587,10 +588,16 @@ def __init__( agent: The agent that owns the history providers. session: The active session for the current run. providers: The history providers participating in per-service-call persistence. + service_stores_history: When True, the chat client stores history server-side. The + middleware then skips loading providers and leaves the real conversation id + untouched, persisting each service call without driving the function loop with a + local sentinel. When False, the middleware loads providers and uses a local + sentinel conversation id so the function loop runs without service-side storage. """ self._agent = agent self._session = session self._providers = list(providers) + self._service_stores_history = service_stores_history async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext: """Create a per-call SessionContext and load history providers into it.""" @@ -602,6 +609,9 @@ async def _prepare_service_call_context(self, messages: Sequence[Message]) -> Se ) for source_id, source_messages in context_messages.items(): service_call_context.extend_messages(source_id, source_messages) + # When the service stores history, it owns loading; the providers are write-only sinks. 
+ if self._service_stores_history: + return service_call_context for provider in self._providers: if not provider.load_messages: continue @@ -652,7 +662,11 @@ async def _finalize_response( response: ChatResponse, ) -> ChatResponse: """Persist a model response and apply the local follow-up sentinel when needed.""" - if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id): + if ( + not self._service_stores_history + and response.conversation_id is not None + and not is_local_history_conversation_id(response.conversation_id) + ): raise ChatClientInvalidResponseException( "require_per_service_call_history_persistence cannot be used " "when the chat client returns a real conversation_id." @@ -662,7 +676,9 @@ async def _finalize_response( service_call_context=service_call_context, response=response, ) - if _response_contains_follow_up_request(response): + # The local sentinel only applies when the service does not store history; when it does, + # the real conversation id already drives function-loop continuation. + if not self._service_stores_history and _response_contains_follow_up_request(response): response.mark_internal_conversation_id() response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID return response @@ -681,8 +697,12 @@ async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[ result type for streaming or non-streaming execution. """ service_call_context = await self._prepare_service_call_context(context.messages) - context.messages = service_call_context.get_messages(include_input=True) - self._strip_local_conversation_id(context) + # When the service stores history, leave the outgoing messages and the real conversation + # id untouched (pass-through); the middleware only persists. Otherwise reconstruct the + # outgoing messages from the loaded local history and strip the local sentinel. 
+ if not self._service_stores_history: + context.messages = service_call_context.get_messages(include_input=True) + self._strip_local_conversation_id(context) await call_next() diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index f8e460e127..2bb3a9c6e6 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -3,6 +3,7 @@ import contextlib import inspect import json +import logging from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch @@ -42,6 +43,8 @@ from agent_framework._middleware import FunctionInvocationContext from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException +from .conftest import MockBaseChatClient + class _FixedTokenizer: def __init__(self, token_count: int) -> None: @@ -609,6 +612,7 @@ def lookup_weather(location: str) -> str: async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default( chat_client_base: SupportsChatGetResponse, + caplog: pytest.LogCaptureFixture, ) -> None: provider = _RecordingHistoryProvider() @@ -649,15 +653,22 @@ def lookup_weather(location: str) -> str: require_per_service_call_history_persistence=True, ) - result = await agent.run("What's the weather in Seattle?", session=session) + with caplog.at_level(logging.WARNING, logger="agent_framework"): + result = await agent.run("What's the weather in Seattle?", session=session) provider_state = session.state[provider.source_id] assert result.text == "It is sunny in Seattle." assert result.response_id == "resp_call_2" assert chat_client_base.call_count == 2 + # The service owns the conversation, so the provider never loads (issue #5798). 
assert "get_call_count" not in provider_state - assert "save_call_count" not in provider_state + # Persistence is owned by the per-service-call middleware: it persists once per service call + # (issue #5798: the provider must never be silently bypassed when the service stores history). + # This run makes two service calls (function call + final answer), so it persists twice. + assert provider_state["save_call_count"] == 2 + # load_messages=True while the service stores history surfaces a warning. + assert any("load_messages" in record.message for record in caplog.records) assert session.service_session_id == "resp_service_managed" @@ -1996,6 +2007,19 @@ def test_merge_options_none_values_ignored(): assert result["key2"] == "value2" +def test_merge_options_drops_none_base_values(): + """Test _merge_options strips None values so unset options are never forwarded.""" + base = {"store": None, "temperature": 0.5} + override = {"top_p": 0.9} + + result = _merge_options(base, override) + + # An unset base value (e.g. store=None from default_options) must not survive the merge. + assert "store" not in result + assert result["temperature"] == 0.5 + assert result["top_p"] == 0.9 + + def test_merge_options_runtime_model_overrides_default_model() -> None: """Test _merge_options lets a runtime model override a default model.""" result = _merge_options({"model": "default-model"}, {"model": "runtime-model"}) @@ -2658,3 +2682,344 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo assert len(exc_info.value.contents) == 1 assert exc_info.value.contents[0].type == "oauth_consent_request" assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" + + +# region Per-service-call history persistence scenario matrix +# +# The driving field is ``require_per_service_call_history_persistence``. 
Every scenario runs a +# single agent run that makes **two service calls** -- a function call followed by a final +# completion -- so the *timing* of persistence is observable: +# +# * When the flag is ``True``, the per-service-call middleware persists the provider **after each +# service call**. So the function-call turn is already saved by the time the second (final) +# service call starts. This holds regardless of whether the chat client stores history +# server-side (the bug in issue #5798 was that a storing client silently bypassed persistence). +# * When the flag is ``False``, the provider persists **once, at the end of the run** -- nothing is +# saved between the two service calls. +# +# ``_PscSpyChatClient.saves_before_call`` records ``provider.save_calls`` at the start of every service +# call, so ``[0, 1]`` means "the function-call turn was persisted before the final call" and +# ``[0, 0]`` means "no persistence happened mid-run". The client's ``store`` / ``STORES_BY_DEFAULT`` +# only selects *how* the middleware behaves -- never *whether* the provider persists. 
+ +_PSC_SERVICE_CONVERSATION_ID = "svc-conversation" + +_psc_stream_params = pytest.mark.parametrize("stream", [False, True], ids=["sync", "stream"]) + + +@tool(name="lookup_weather", approval_mode="never_require") +def _psc_lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + +def _psc_function_call_script() -> list[tuple[str, ...]]: + """A fresh function-call-then-final-completion script (the client mutates it).""" + return [ + ("call", "call_1", "lookup_weather", '{"location": "Seattle"}'), + ("text", "It is sunny in Seattle."), + ] + + +class _PscSpyHistoryProvider(HistoryProvider): + """In-memory history provider that records load/save calls for assertions.""" + + def __init__(self, source_id: str = "spy_history", **kwargs: Any) -> None: + super().__init__(source_id, **kwargs) + self._messages: list[Message] = [] + self.get_calls: int = 0 + self.save_calls: int = 0 + self.saved_batches: list[list[Message]] = [] + + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: + self.get_calls += 1 + return list(self._messages) + + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + self.save_calls += 1 + self.saved_batches.append(list(messages)) + self._messages.extend(messages) + + @property + def stored_messages(self) -> list[Message]: + return list(self._messages) + + +class _PscSpyChatClient(MockBaseChatClient): + """Chat client that scripts a function-call/final-completion sequence. + + It records, at the start of each service call, how many provider saves have already happened + (``saves_before_call``), what messages it received, and what options it saw. When the effective + ``store`` is truthy it returns a stable ``conversation_id`` to mimic a server-managed + conversation, so the framework propagates ``session.service_session_id``. 
+ """ + + def __init__( + self, + *, + provider: _PscSpyHistoryProvider, + stores_by_default: bool = False, + script: list[tuple[str, ...]] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.STORES_BY_DEFAULT = stores_by_default # type: ignore[attr-defined] + self._provider = provider + self._script = list(script) if script is not None else [("text", "ok")] + self.received_messages: list[list[Message]] = [] + self.received_options: list[dict[str, Any]] = [] + self.saves_before_call: list[int] = [] + + def _effective_store(self, options: dict[str, Any]) -> bool: + store = options.get("store") + if store is None: + return bool(self.STORES_BY_DEFAULT) + return bool(store) + + def _next_contents(self) -> list[Content]: + turn = self._script.pop(0) if self._script else ("text", "ok") + if turn[0] == "call": + _, call_id, name, args = turn + return [Content.from_function_call(call_id=call_id, name=name, arguments=args)] + return [Content.from_text(turn[1])] + + def _inner_get_response( # type: ignore[override] + self, + *, + messages: MutableSequence[Message], + stream: bool, + options: dict[str, Any], + **kwargs: Any, + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + self.received_messages.append(list(messages)) + self.received_options.append(dict(options)) + self.saves_before_call.append(self._provider.save_calls) + conv_id = _PSC_SERVICE_CONVERSATION_ID if self._effective_store(options) else None + contents = self._next_contents() + + if stream: + + async def _stream() -> AsyncIterable[ChatResponseUpdate]: + self.call_count += 1 + yield ChatResponseUpdate( + contents=contents, + role="assistant", + finish_reason="stop", + conversation_id=conv_id, + ) + + def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: + response = ChatResponse.from_updates(updates, output_format_type=options.get("response_format")) + if conv_id: + response.conversation_id = conv_id + return response + + return 
ResponseStream(_stream(), finalizer=_finalize) + + async def _get() -> ChatResponse: + self.call_count += 1 + return ChatResponse( + messages=Message(role="assistant", contents=contents), + conversation_id=conv_id, + ) + + return _get() + + +def _psc_build_agent( + client: _PscSpyChatClient, + provider: _PscSpyHistoryProvider, + *, + require_per_service_call_history_persistence: bool, + default_options: dict[str, Any] | None = None, +) -> Agent: + kwargs: dict[str, Any] = {} + if default_options is not None: + kwargs["default_options"] = default_options + return Agent( + client=client, + tools=[_psc_lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=require_per_service_call_history_persistence, + **kwargs, + ) + + +async def _psc_run(agent: Agent, text: str, session: AgentSession, *, stream: bool) -> str: + if stream: + chunks: list[str] = [] + async for update in agent.run(text, session=session, stream=True): + chunks.append(update.text or "") + return "".join(chunks) + result = await agent.run(text, session=session) + return result.text + + +# driver=True (the contract under test): persistence happens per service call + + +@_psc_stream_params +async def test_psc_flag_on_store_false_persists_after_each_service_call(stream: bool) -> None: + """Mode A (flag on, service does not store): function-call turn is persisted before the final call.""" + provider = _PscSpyHistoryProvider() + client = _PscSpyChatClient(provider=provider, stores_by_default=False, script=_psc_function_call_script()) + agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True) + session = agent.create_session() + + text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert text == "It is sunny in Seattle." + # Two service calls: function call, then final completion. 
+ assert client.call_count == 2 + # The contract: the function-call turn was persisted *before* the second service call started. + assert client.saves_before_call == [0, 1] + assert provider.save_calls == 2 + # Mode A loads local history (the middleware injects it before each service call). + assert provider.get_calls >= 1 + # No service-side storage, so no conversation id is propagated. + assert session.service_session_id is None + + +@_psc_stream_params +async def test_psc_flag_on_stores_by_default_persists_after_each_service_call( + stream: bool, caplog: pytest.LogCaptureFixture +) -> None: + """Mode B (flag on, service stores by default): still persists per service call, but skips load (issue #5798).""" + provider = _PscSpyHistoryProvider() # load_messages=True by default + client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script()) + agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True) + session = agent.create_session() + + with caplog.at_level(logging.WARNING, logger="agent_framework"): + text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert text == "It is sunny in Seattle." + assert client.call_count == 2 + # The invariant the bug violated: persistence still happens per service call when the service stores. + assert client.saves_before_call == [0, 1] + assert provider.save_calls == 2 + # The service owns loading, so the provider is never asked to load. + assert provider.get_calls == 0 + # A warning surfaces the bypassed load (load_messages=True). + assert any("load_messages" in record.message for record in caplog.records) + # The real service conversation id propagates to the session. 
+ assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID + + +@_psc_stream_params +async def test_psc_flag_on_store_only_provider_no_load_no_warning( + stream: bool, caplog: pytest.LogCaptureFixture +) -> None: + """Mode B with a store-only provider (load_messages=False): persists per call, no load, no warning.""" + provider = _PscSpyHistoryProvider(load_messages=False) + client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script()) + agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True) + session = agent.create_session() + + with caplog.at_level(logging.WARNING, logger="agent_framework"): + await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert client.saves_before_call == [0, 1] + assert provider.save_calls == 2 + assert provider.get_calls == 0 + assert not any("load_messages" in record.message for record in caplog.records) + + +@_psc_stream_params +async def test_psc_flag_on_store_false_override_behaves_as_mode_a(stream: bool) -> None: + """Flag on + storing client but store=False override: falls back to Mode A (local, per call).""" + provider = _PscSpyHistoryProvider() + client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script()) + agent = _psc_build_agent( + client, provider, require_per_service_call_history_persistence=True, default_options={"store": False} + ) + session = agent.create_session() + + await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert client.saves_before_call == [0, 1] + assert provider.save_calls == 2 + assert provider.get_calls >= 1 + # store=False forces local handling, so no real service conversation id. 
+ assert session.service_session_id is None + + +@_psc_stream_params +async def test_psc_flag_on_store_none_treated_as_absent(stream: bool, caplog: pytest.LogCaptureFixture) -> None: + """Flag on + storing client + explicit store=None: None is "unset", so the storing default applies (Mode B).""" + provider = _PscSpyHistoryProvider() + client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script()) + agent = _psc_build_agent( + client, provider, require_per_service_call_history_persistence=True, default_options={"store": None} + ) + session = agent.create_session() + + with caplog.at_level(logging.WARNING, logger="agent_framework"): + await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert client.saves_before_call == [0, 1] + assert provider.save_calls == 2 + assert provider.get_calls == 0 + assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID + assert any("load_messages" in record.message for record in caplog.records) + # store=None must not be forwarded to the client; the service decides its own default. + assert all("store" not in options for options in client.received_options) + + +@_psc_stream_params +async def test_psc_flag_on_respects_store_outputs_flag(stream: bool) -> None: + """Flag on: the provider's store_inputs/store_outputs flags still apply per service call.""" + provider = _PscSpyHistoryProvider(store_inputs=True, store_outputs=False) + client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script()) + agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True) + session = agent.create_session() + + await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert provider.save_calls == 2 + # Outputs disabled, so no assistant/tool-call messages were stored, only user/tool inputs. 
+ assert provider.stored_messages + assert all(message.role != "assistant" for message in provider.stored_messages) + + +# driver=False (control): persistence happens once, at the end of the run + + +@_psc_stream_params +async def test_psc_flag_off_store_false_persists_once_at_end(stream: bool) -> None: + """Flag off + non-storing client: nothing is persisted mid-run; one save at the end.""" + provider = _PscSpyHistoryProvider() + client = _PscSpyChatClient(provider=provider, stores_by_default=False, script=_psc_function_call_script()) + agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=False) + session = agent.create_session() + + text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert text == "It is sunny in Seattle." + assert client.call_count == 2 + # The control contract: no save happened between the function call and the final completion. + assert client.saves_before_call == [0, 0] + assert provider.save_calls == 1 + + +@_psc_stream_params +async def test_psc_flag_off_stores_by_default_persists_once_at_end(stream: bool) -> None: + """Flag off + storing client: once-per-run persistence, and the service conversation id propagates.""" + provider = _PscSpyHistoryProvider() + client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script()) + agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=False) + session = agent.create_session() + + await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream) + + assert client.saves_before_call == [0, 0] + assert provider.save_calls == 1 + assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID