From 5d06e08f8218a889ecd178116e26937dd79bd10b Mon Sep 17 00:00:00 2001 From: Stefan Gersmann Date: Wed, 20 May 2026 13:30:30 +0200 Subject: [PATCH] fix(app-server): surface turn errors and process events --- codex/app_server/_async_client.py | 31 +++- codex/app_server/_async_threads.py | 25 ++++ codex/app_server/_sync_threads.py | 18 +++ tests/test_app_server_async_client.py | 197 +++++++++++++++++++++++++- tests/test_stream_interaction.py | 12 +- 5 files changed, 274 insertions(+), 9 deletions(-) diff --git a/codex/app_server/_async_client.py b/codex/app_server/_async_client.py index a255181..ad8cf52 100644 --- a/codex/app_server/_async_client.py +++ b/codex/app_server/_async_client.py @@ -22,7 +22,7 @@ from codex.app_server._async_threads import AsyncAppServerThread as AsyncAppServerThread from codex.app_server._async_threads import AsyncTurnStream as AsyncTurnStream from codex.app_server._async_threads import _ThreadClient -from codex.app_server._protocol_helpers import RequestHandler +from codex.app_server._protocol_helpers import Notification, RequestHandler from codex.app_server._session import _AsyncNotificationSubscription, _AsyncSession from codex.app_server.models import ( InitializeResult, @@ -107,6 +107,35 @@ def __init__(self, session: _AsyncSession) -> None: def subscribe(self, methods: Collection[str] | None = None) -> _AsyncNotificationSubscription: return self._session.subscribe_notifications(methods) + def subscribe_command_exec_output(self, process_id: str) -> _AsyncNotificationSubscription: + """Subscribe to `command/exec/outputDelta` notifications for one process id.""" + + def predicate(notification: Notification) -> bool: + return ( + isinstance(notification, protocol.CommandExecOutputDeltaNotificationModel) + and notification.params.processId == process_id + ) + + return self._session.subscribe_notifications( + {"command/exec/outputDelta"}, + predicate=predicate, + ) + + def subscribe_process_events(self, process_handle: str) -> _AsyncNotificationSubscription: + """Subscribe to `process/outputDelta` and `process/exited` for one process handle.""" + + def predicate(notification: Notification) -> bool: + if isinstance(notification, protocol.ProcessOutputDeltaNotificationModel): + return notification.params.processHandle == process_handle + if isinstance(notification, protocol.ProcessExitedNotificationModel): + return notification.params.processHandle == process_handle + return False + + return self._session.subscribe_notifications( + {"process/outputDelta", "process/exited"}, + predicate=predicate, + ) + class AsyncAppServerClient: """Async client for `codex app-server`.""" diff --git a/codex/app_server/_async_threads.py b/codex/app_server/_async_threads.py index 369dc48..354fae7 100644 --- a/codex/app_server/_async_threads.py +++ b/codex/app_server/_async_threads.py @@ -37,6 +37,7 @@ DEFAULT_REVIEW_DELIVERY = protocol.ReviewDelivery("inline") _TURN_STREAM_NOTIFICATION_METHODS = { + "error", "turn/started", "turn/completed", "turn/diff/updated", @@ -54,6 +55,7 @@ "item/reasoning/summaryPartAdded", "item/reasoning/textDelta", "item/commandExecution/outputDelta", + "item/commandExecution/terminalInteraction", "item/fileChange/outputDelta", "serverRequest/resolved", } @@ -132,6 +134,7 @@ def __init__( self.usage: protocol.ThreadTokenUsage | None = None self._item_index: dict[str, int] = {} self._text_deltas: list[str] = [] + self._retryable_error_notifications: list[protocol.ErrorNotificationModel] = [] self._done = False self._closed = False @@ -188,6 +191,16 @@ async def __anext__(self) -> Notification: raise StopAsyncIteration notification = await self._subscription.next() self._apply(notification) + if isinstance(notification, protocol.ErrorNotificationModel): + if not notification.params.willRetry: + self._done = True + await self.close() + error = notification.params.error + message = error.message + if error.additionalDetails is not None and error.additionalDetails != "": + message = f"{message}: {error.additionalDetails}" + raise AppServerTurnError(message) + self._retryable_error_notifications.append(notification) if isinstance(notification, protocol.TurnCompletedNotificationModel): self._done = True return notification @@ -274,6 +287,18 @@ def text_deltas(self) -> tuple[str, ...]: """Return the streamed agent text deltas received so far.""" return tuple(self._text_deltas) + @property + def retryable_error_notifications(self) -> tuple[protocol.ErrorNotificationModel, ...]: + """Return retryable turn error notifications received so far.""" + return tuple(self._retryable_error_notifications) + + @property + def retryable_errors(self) -> tuple[protocol.TurnError, ...]: + """Return retryable turn errors received so far.""" + return tuple( + notification.params.error for notification in self._retryable_error_notifications + ) + def _apply(self, notification: Notification) -> None: self._apply_text_delta(notification) self._apply_token_usage(notification) diff --git a/codex/app_server/_sync_threads.py b/codex/app_server/_sync_threads.py index 50a41c0..a65048b 100644 --- a/codex/app_server/_sync_threads.py +++ b/codex/app_server/_sync_threads.py @@ -25,6 +25,10 @@ def subscribe( methods: Collection[str] | None = None, ) -> _AsyncNotificationSubscription: ... + def subscribe_command_exec_output(self, process_id: str) -> _AsyncNotificationSubscription: ... + + def subscribe_process_events(self, process_handle: str) -> _AsyncNotificationSubscription: ... + class _AsyncTurnStreamLike(Protocol): initial_turn: protocol.Turn @@ -157,6 +161,20 @@ def subscribe(self, methods: Collection[str] | None = None) -> NotificationSubsc self._run, ) + def subscribe_command_exec_output(self, process_id: str) -> NotificationSubscription: + """Subscribe to `command/exec/outputDelta` notifications for one process id.""" + return NotificationSubscription( + self._async_events.subscribe_command_exec_output(process_id), + self._run, + ) + + def subscribe_process_events(self, process_handle: str) -> NotificationSubscription: + """Subscribe to `process/outputDelta` and `process/exited` for one process handle.""" + return NotificationSubscription( + self._async_events.subscribe_process_events(process_handle), + self._run, + ) + class TurnStream(_SyncRunner): """Synchronous iterator over protocol-native notifications for a single turn.""" diff --git a/tests/test_app_server_async_client.py b/tests/test_app_server_async_client.py index ca4c984..8106608 100644 --- a/tests/test_app_server_async_client.py +++ b/tests/test_app_server_async_client.py @@ -5,7 +5,7 @@ import pytest from codex.app_server._async_client import AsyncEventsClient, AsyncTurnStream -from codex.app_server.errors import AppServerProtocolError +from codex.app_server.errors import AppServerProtocolError, AppServerTurnError from codex.app_server.models import ReviewResult from codex.protocol import types as protocol @@ -34,6 +34,17 @@ def update_predicate(self, predicate: object) -> None: self.updated_predicate = predicate +class _QueuedSubscription(_FakeSubscription): + def __init__(self, notifications: list[protocol.Notification]) -> None: + super().__init__() + self._notifications = notifications + + async def next(self) -> protocol.Notification: + if not self._notifications: + raise StopAsyncIteration + return self._notifications.pop(0) + + class _FakeThread: def __init__(self) -> None: self.id = "thr-1" @@ -76,6 +87,94 @@ def test_async_events_client_subscribe_delegates_to_session() -> None: assert session.calls == [(["turn/completed"], None)] +def test_async_events_client_subscribe_command_exec_output_filters_by_process_id() -> None: + session = _FakeSession() + events = AsyncEventsClient(session) # type: ignore[arg-type] + + subscription = events.subscribe_command_exec_output("proc-1") + + assert subscription == "subscription" + methods, predicate = session.calls[0] + assert methods == {"command/exec/outputDelta"} + assert callable(predicate) + matching = protocol.CommandExecOutputDeltaNotificationModel.model_validate( + { + "method": "command/exec/outputDelta", + "params": { + "capReached": False, + "deltaBase64": "aGVsbG8K", + "processId": "proc-1", + "stream": "stdout", + }, + } + ) + other_process = protocol.CommandExecOutputDeltaNotificationModel.model_validate( + { + "method": "command/exec/outputDelta", + "params": { + "capReached": False, + "deltaBase64": "aGVsbG8K", + "processId": "proc-2", + "stream": "stdout", + }, + } + ) + + assert predicate(matching) is True + assert predicate(other_process) is False + + +def test_async_events_client_subscribe_process_events_filters_by_process_handle() -> None: + session = _FakeSession() + events = AsyncEventsClient(session) # type: ignore[arg-type] + + subscription = events.subscribe_process_events("proc-handle-1") + + assert subscription == "subscription" + methods, predicate = session.calls[0] + assert methods == {"process/outputDelta", "process/exited"} + assert callable(predicate) + output = protocol.ProcessOutputDeltaNotificationModel.model_validate( + { + "method": "process/outputDelta", + "params": { + "capReached": False, + "deltaBase64": "aGVsbG8K", + "processHandle": "proc-handle-1", + "stream": "stdout", + }, + } + ) + exited = protocol.ProcessExitedNotificationModel.model_validate( + { + "method": "process/exited", + "params": { + "exitCode": 0, + "processHandle": "proc-handle-1", + "stderr": "", + "stderrCapReached": False, + "stdout": "", + "stdoutCapReached": False, + }, + } + ) + other_handle = protocol.ProcessOutputDeltaNotificationModel.model_validate( + { + "method": "process/outputDelta", + "params": { + "capReached": False, + "deltaBase64": "aGVsbG8K", + "processHandle": "proc-handle-2", + "stream": "stdout", + }, + } + ) + + assert predicate(output) is True + assert predicate(exited) is True + assert predicate(other_handle) is False + + def test_async_turn_stream_scope_predicate_filters_by_thread_and_turn() -> None: predicate = AsyncTurnStream._scope_predicate("thr-1", "turn-1") @@ -339,6 +438,102 @@ async def scenario() -> None: asyncio.run(scenario()) +def test_async_turn_stream_raises_and_closes_on_non_retryable_error_notification() -> None: + error_notification = protocol.ErrorNotificationModel.model_validate( + { + "method": "error", + "params": { + "threadId": "thr-1", + "turnId": "turn-1", + "willRetry": False, + "error": { + "message": "model unavailable", + "additionalDetails": "try another model", + }, + }, + } + ) + + async def scenario() -> None: + subscription = _QueuedSubscription([error_notification]) + stream = AsyncTurnStream( + _FakeThread(), # type: ignore[arg-type] + subscription, # type: ignore[arg-type] + protocol.Turn.model_validate(_turn_payload(status="inProgress")), + ) + + with pytest.raises(AppServerTurnError, match="model unavailable: try another model"): + await stream.__anext__() + + assert subscription.closed is True + + asyncio.run(scenario()) + + +def test_async_turn_stream_yields_retryable_error_notification() -> None: + error_notification = protocol.ErrorNotificationModel.model_validate( + { + "method": "error", + "params": { + "threadId": "thr-1", + "turnId": "turn-1", + "willRetry": True, + "error": {"message": "temporary outage"}, + }, + } + ) + + async def scenario() -> None: + subscription = _QueuedSubscription([error_notification]) + stream = AsyncTurnStream( + _FakeThread(), # type: ignore[arg-type] + subscription, # type: ignore[arg-type] + protocol.Turn.model_validate(_turn_payload(status="inProgress")), + ) + + assert await stream.__anext__() == error_notification + assert subscription.closed is False + + asyncio.run(scenario()) + + +def test_async_turn_stream_wait_preserves_retryable_error_notifications() -> None: + error_notification = protocol.ErrorNotificationModel.model_validate( + { + "method": "error", + "params": { + "threadId": "thr-1", + "turnId": "turn-1", + "willRetry": True, + "error": { + "message": "temporary outage", + "additionalDetails": "retrying with fallback", + }, + }, + } + ) + turn_completed = protocol.TurnCompletedNotificationModel.model_validate( + { + "method": "turn/completed", + "params": {"threadId": "thr-1", "turn": _turn_payload(status="completed")}, + } + ) + + async def scenario() -> None: + subscription = _QueuedSubscription([error_notification, turn_completed]) + stream = AsyncTurnStream( + _FakeThread(), # type: ignore[arg-type] + subscription, # type: ignore[arg-type] + protocol.Turn.model_validate(_turn_payload(status="inProgress")), + ) + + assert await stream.wait() is stream + assert stream.retryable_error_notifications == (error_notification,) + assert stream.retryable_errors == (error_notification.params.error,) + + asyncio.run(scenario()) + + def test_async_turn_stream_raise_for_terminal_status_requires_completion() -> None: stream = AsyncTurnStream( _FakeThread(), # type: ignore[arg-type] diff --git a/tests/test_stream_interaction.py b/tests/test_stream_interaction.py index 6b708d5..f6d104b 100644 --- a/tests/test_stream_interaction.py +++ b/tests/test_stream_interaction.py @@ -323,13 +323,12 @@ def test_consume_two_turns_same_thread() -> None: client.close() -def test_turn_completed_not_in_notification_methods_does_not_hang() -> None: +def test_turn_stream_yields_terminal_interaction_and_does_not_hang() -> None: """Reproduce review-action handling when terminalInteraction is emitted. - codex-review-action handles `item/commandExecution/terminalInteraction`, but - `_TURN_STREAM_NOTIFICATION_METHODS` does not subscribe to that method. - This test verifies such events are effectively dropped for the turn stream and - do not cause `wait()` to hang. + codex-review-action handles `item/commandExecution/terminalInteraction`; + the turn stream should now deliver it as a typed protocol notification and + still allow a later `wait()` call to finish. """ client, transport = _make_sync_client() @@ -337,7 +336,6 @@ def test_turn_completed_not_in_notification_methods_does_not_hang() -> None: thread = client.start_thread() stream = thread.run("Terminal interaction compatibility") - # This method is intentionally *not* part of turn stream subscribed methods. transport.push( { "method": "item/commandExecution/terminalInteraction", @@ -374,8 +372,8 @@ def test_turn_completed_not_in_notification_methods_does_not_hang() -> None: task_complete, events = _consume_like_review_action(stream) assert task_complete is True - # If terminalInteraction had been routed to stream, we'd see an extra event. assert [type(event) for event in events] == [ + protocol.ItemCommandExecutionTerminalInteractionNotification, protocol.ItemCompletedNotificationModel, protocol.TurnCompletedNotificationModel, ]