-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Python: Fix per-service-call history persistence with server-storing clients #6310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -92,12 +92,16 @@ | |
| def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: | ||
| """Merge two options dicts, with override values taking precedence. | ||
|
|
||
| ``None`` is treated as "unset": ``None`` overrides are skipped so they don't clobber a base | ||
| value, and the merged result is stripped of any remaining ``None`` values in a final pass so | ||
| unset options are never forwarded (e.g. an unset ``store`` is left for the service to default). | ||
|
|
||
| Args: | ||
| base: The base options dict. | ||
| override: The override options dict (values take precedence). | ||
|
|
||
| Returns: | ||
| A new merged options dict. | ||
| A new merged options dict containing no ``None`` values. | ||
| """ | ||
| result = dict(base) | ||
|
|
||
|
|
@@ -123,7 +127,7 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, | |
| result["instructions"] = f"{result['instructions']}\n{value}" | ||
| else: | ||
| result[key] = value | ||
| return result | ||
| return {key: value for key, value in result.items() if value is not None} | ||
|
|
||
|
|
||
| def _sanitize_agent_name(agent_name: str | None) -> str | None: | ||
|
|
@@ -460,6 +464,9 @@ async def _run_after_providers( | |
| if provider_session is None and self.context_providers: | ||
| provider_session = AgentSession() | ||
|
|
||
| # When per-service-call persistence is enabled, the per-service-call middleware owns | ||
| # HistoryProvider persistence (in both the local and service-managed cases), so skip | ||
| # them on the once-per-run path to avoid double persistence. | ||
| per_service_call_history_required = self.require_per_service_call_history_persistence and any( | ||
| isinstance(provider, HistoryProvider) for provider in self.context_providers | ||
| ) | ||
|
|
@@ -686,11 +693,14 @@ def __init__( | |
| description: A brief description of the agent's purpose. | ||
| context_providers: Context providers to include during agent invocation. | ||
| middleware: List of middleware to intercept agent and function invocations. | ||
| require_per_service_call_history_persistence: When True, history providers are invoked | ||
| around each model call instead of once per ``run()`` when the service | ||
| is not already storing history. If service-side storage is active for | ||
| the run, the agent skips local history providers and relies on the | ||
| service-managed conversation instead. | ||
| require_per_service_call_history_persistence: When True (and a HistoryProvider is | ||
| present), the provider always persists history via per-service-call middleware, | ||
| regardless of whether the client stores history server-side. If the client does | ||
| not store history, the middleware also loads providers around each model call and | ||
| drives the function loop with a local conversation; if it does, loading is skipped | ||
| (the service-managed conversation is the source of truth) and the middleware only | ||
| persists. A warning is logged for providers with ``load_messages=True`` when | ||
| loading is skipped because service-side storage is active. | ||
| default_options: A TypedDict containing chat options. When using a typed agent like | ||
| ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for | ||
| provider-specific options including temperature, max_tokens, model, | ||
|
|
@@ -791,22 +801,20 @@ def _resolve_per_service_call_history_providers( | |
| self, | ||
| *, | ||
| session: AgentSession | None, | ||
| options: Mapping[str, Any] | None, | ||
| conversation_id: str | None, | ||
| service_stores_history: bool, | ||
| ) -> list[HistoryProvider]: | ||
| history_providers = self._get_history_providers() | ||
| if not self.require_per_service_call_history_persistence or not history_providers: | ||
| return [] | ||
|
|
||
| conversation_id = ( | ||
| session.service_session_id | ||
| if session and session.service_session_id | ||
| else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id")) | ||
| ) | ||
| if service_stores_history: | ||
| return [] | ||
|
|
||
| if conversation_id is not None: | ||
| # A live service-managed session id takes precedence over the resolved conversation id. | ||
| if session and session.service_session_id: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we covering a second run on the same session in storing mode? This precedence branch only fires once … |
||
| conversation_id = session.service_session_id | ||
| # Without service-side storage the middleware persists locally and drives the function | ||
| # loop with a local sentinel, which cannot be reconciled with an existing service-managed | ||
| # conversation. When the service stores history, an existing conversation id is expected. | ||
| if conversation_id is not None and not service_stores_history: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we covering the allow side of this boundary? The PR adds … |
||
| raise AgentInvalidRequestException( | ||
| "require_per_service_call_history_persistence cannot be used " | ||
| "with an existing service-managed conversation." | ||
|
|
@@ -1167,18 +1175,34 @@ async def _prepare_run_context( | |
|
|
||
| input_messages = normalize_messages(messages) | ||
|
|
||
| # `store` in runtime or agent options takes precedence over client-level storage | ||
| # indicators. An explicit `store=False` forces local (in-memory) history injection, | ||
| # even if the client is configured to use service-side storage by default. | ||
| store_ = opts.get("store", self.default_options.get("store", getattr(self.client, "STORES_BY_DEFAULT", False))) | ||
| # Combine agent-level defaults with runtime options up front so the decisions below read | ||
| # `store` from a single place rather than introspecting both dicts. _merge_options applies | ||
| # the same precedence used for the actual client call (runtime wins; unset/None falls back | ||
| # to the agent default). | ||
| effective_options = _merge_options(self.default_options, opts) | ||
|
|
||
| # `store` in runtime or agent options takes precedence over the client's default | ||
| # storage behavior. An explicit `store=False` forces local (in-memory) history | ||
| # injection even when the client stores server-side by default; an explicit | ||
| # `store=True` forces service-side storage. A `store=None`/unset value means the | ||
| # service falls back to its own default. | ||
| explicit_store = effective_options.get("store") | ||
| # Internal behavior hint: will the service own history for this run? Only when the | ||
| # user left `store` unset do we fall back to the client's STORES_BY_DEFAULT. | ||
| service_stores_history = ( | ||
| explicit_store if explicit_store is not None else getattr(self.client, "STORES_BY_DEFAULT", False) | ||
| ) | ||
| # Resolve conversation_id from the same combined view so an agent-level default is honored | ||
| # when the runtime omits it (a live session id still takes precedence below). | ||
| effective_conversation_id = effective_options.get("conversation_id") | ||
| # Auto-inject InMemoryHistoryProvider when session is provided, no context providers | ||
| # registered, and no service-side storage indicators | ||
| if ( | ||
| session is not None | ||
| and not self.context_providers | ||
| and not session.service_session_id | ||
| and not opts.get("conversation_id") | ||
| and not store_ | ||
| and not effective_conversation_id | ||
| and not service_stores_history | ||
| ): | ||
| self.context_providers.append(InMemoryHistoryProvider()) | ||
|
|
||
|
|
@@ -1188,10 +1212,30 @@ async def _prepare_run_context( | |
|
|
||
| per_service_call_history_providers = self._resolve_per_service_call_history_providers( | ||
| session=active_session, | ||
| options=opts, | ||
| service_stores_history=bool(store_), | ||
| conversation_id=effective_conversation_id, | ||
| service_stores_history=service_stores_history, | ||
| ) | ||
|
|
||
| # When require_per_service_call_history_persistence is set together with a | ||
| # HistoryProvider, the per-service-call middleware (installed below) always persists | ||
| # the provider. ``service_stores_history`` only selects how the middleware behaves: | ||
| # - service does not store: the middleware also loads providers and drives the function | ||
| # loop with a local sentinel conversation id, or | ||
| # - service stores: the middleware skips loading (the service owns history) and simply | ||
| # persists each service call while the real conversation id flows through. | ||
| # In the service-managed case loading is skipped, so warn for providers that expect to load. | ||
| history_providers = self._get_history_providers() | ||
| if self.require_per_service_call_history_persistence and history_providers and service_stores_history: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This sits in … |
||
| for provider in history_providers: | ||
| if provider.load_messages: | ||
| logger.warning( | ||
| "HistoryProvider '%s' has load_messages=True but the chat client stores history " | ||
| "server-side; skipping local history load and relying on the service-managed " | ||
| "conversation. Set store=False to load from the provider, or load_messages=False " | ||
| "to silence this warning.", | ||
| provider.source_id, | ||
|
Comment on lines +1232 to +1236
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not throw an exception at agent initialization? |
||
| ) | ||
|
|
||
| session_context, chat_options = await self._prepare_session_and_messages( | ||
| session=active_session, | ||
| input_messages=input_messages, | ||
|
|
@@ -1265,8 +1309,8 @@ async def _prepare_run_context( | |
| } | ||
| if model is not None: | ||
| run_opts["model"] = model | ||
| # Remove None values and merge with chat_options | ||
| run_opts = {k: v for k, v in run_opts.items() if v is not None} | ||
| # _merge_options strips unset (None) options, so e.g. an unset `store` is not forwarded | ||
| # and the service decides its own default. | ||
| co = _merge_options(chat_options, run_opts) | ||
|
|
||
| # Build session_messages from session context: context messages + input messages | ||
|
|
@@ -1280,6 +1324,7 @@ async def _prepare_run_context( | |
| agent=self, | ||
| session=active_session, | ||
| providers=per_service_call_history_providers, | ||
| service_stores_history=service_stores_history, | ||
| ) | ||
| existing_middleware = effective_client_kwargs.get("middleware") | ||
| if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)): | ||
|
|
@@ -1319,7 +1364,7 @@ async def _prepare_run_context( | |
| "input_messages": input_messages, | ||
| "session_messages": session_messages, | ||
| "agent_name": agent_name, | ||
| "suppress_response_id": bool(per_service_call_history_providers), | ||
| "suppress_response_id": bool(per_service_call_history_providers) and not service_stores_history, | ||
| "chat_options": co, | ||
| "compaction_strategy": compaction_strategy or self.compaction_strategy, | ||
| "tokenizer": tokenizer or self.tokenizer, | ||
|
|
@@ -1413,11 +1458,15 @@ async def _prepare_session_and_messages( | |
| options=options or {}, | ||
| ) | ||
|
|
||
| # When per-service-call persistence is enabled, the per-service-call middleware owns | ||
| # HistoryProvider loading (it loads locally when the service does not store history, or | ||
| # relies on the service when it does), so skip them on the once-per-run before_run path. | ||
| per_service_call_history_required = self.require_per_service_call_history_persistence and bool( | ||
| self._get_history_providers() | ||
| ) | ||
|
|
||
| # Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history) | ||
| # Run before_run providers (forward order, skip HistoryProvider when per-service-call | ||
| # persistence owns loading) | ||
| for provider in self.context_providers: | ||
| if per_service_call_history_required and isinstance(provider, HistoryProvider): | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -580,17 +580,24 @@ def __init__( | |
| agent: SupportsAgentRun, | ||
| session: AgentSession, | ||
| providers: Sequence[HistoryProvider], | ||
| service_stores_history: bool = False, | ||
| ) -> None: | ||
| """Initialize the middleware. | ||
|
|
||
| Args: | ||
| agent: The agent that owns the history providers. | ||
| session: The active session for the current run. | ||
| providers: The history providers participating in per-service-call persistence. | ||
| service_stores_history: When True, the chat client stores history server-side. The | ||
| middleware then skips loading providers and leaves the real conversation id | ||
| untouched, persisting each service call without driving the function loop with a | ||
| local sentinel. When False, the middleware loads providers and uses a local | ||
| sentinel conversation id so the function loop runs without service-side storage. | ||
| """ | ||
| self._agent = agent | ||
| self._session = session | ||
| self._providers = list(providers) | ||
| self._service_stores_history = service_stores_history | ||
|
|
||
| async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext: | ||
| """Create a per-call SessionContext and load history providers into it.""" | ||
|
|
@@ -602,6 +609,9 @@ async def _prepare_service_call_context(self, messages: Sequence[Message]) -> Se | |
| ) | ||
| for source_id, source_messages in context_messages.items(): | ||
| service_call_context.extend_messages(source_id, source_messages) | ||
| # When the service stores history, it owns loading; the providers are write-only sinks. | ||
| if self._service_stores_history: | ||
| return service_call_context | ||
| for provider in self._providers: | ||
| if not provider.load_messages: | ||
| continue | ||
|
|
@@ -652,7 +662,11 @@ async def _finalize_response( | |
| response: ChatResponse, | ||
| ) -> ChatResponse: | ||
| """Persist a model response and apply the local follow-up sentinel when needed.""" | ||
| if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id): | ||
| if ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happens when … |
||
| not self._service_stores_history | ||
| and response.conversation_id is not None | ||
| and not is_local_history_conversation_id(response.conversation_id) | ||
| ): | ||
| raise ChatClientInvalidResponseException( | ||
| "require_per_service_call_history_persistence cannot be used " | ||
| "when the chat client returns a real conversation_id." | ||
|
|
@@ -662,7 +676,9 @@ async def _finalize_response( | |
| service_call_context=service_call_context, | ||
| response=response, | ||
| ) | ||
| if _response_contains_follow_up_request(response): | ||
| # The local sentinel only applies when the service does not store history; when it does, | ||
| # the real conversation id already drives function-loop continuation. | ||
| if not self._service_stores_history and _response_contains_follow_up_request(response): | ||
| response.mark_internal_conversation_id() | ||
| response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID | ||
| return response | ||
|
|
@@ -681,8 +697,12 @@ async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[ | |
| result type for streaming or non-streaming execution. | ||
| """ | ||
| service_call_context = await self._prepare_service_call_context(context.messages) | ||
| context.messages = service_call_context.get_messages(include_input=True) | ||
| self._strip_local_conversation_id(context) | ||
| # When the service stores history, leave the outgoing messages and the real conversation | ||
| # id untouched (pass-through); the middleware only persists. Otherwise reconstruct the | ||
| # outgoing messages from the loaded local history and strip the local sentinel. | ||
| if not self._service_stores_history: | ||
| context.messages = service_call_context.get_messages(include_input=True) | ||
| self._strip_local_conversation_id(context) | ||
|
|
||
| await call_next() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this becomes somewhat confusing. What happens when this is True but there isn't a HistoryProvider?