Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion codex/app_server/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from codex.app_server._async_threads import AsyncAppServerThread as AsyncAppServerThread
from codex.app_server._async_threads import AsyncTurnStream as AsyncTurnStream
from codex.app_server._async_threads import _ThreadClient
from codex.app_server._protocol_helpers import RequestHandler
from codex.app_server._protocol_helpers import Notification, RequestHandler
from codex.app_server._session import _AsyncNotificationSubscription, _AsyncSession
from codex.app_server.models import (
InitializeResult,
Expand Down Expand Up @@ -107,6 +107,35 @@ def __init__(self, session: _AsyncSession) -> None:
def subscribe(self, methods: Collection[str] | None = None) -> _AsyncNotificationSubscription:
return self._session.subscribe_notifications(methods)

def subscribe_command_exec_output(self, process_id: str) -> _AsyncNotificationSubscription:
"""Subscribe to `command/exec/outputDelta` notifications for one process id."""

def predicate(notification: Notification) -> bool:
return (
isinstance(notification, protocol.CommandExecOutputDeltaNotificationModel)
and notification.params.processId == process_id
)

return self._session.subscribe_notifications(
{"command/exec/outputDelta"},
predicate=predicate,
)

def subscribe_process_events(self, process_handle: str) -> _AsyncNotificationSubscription:
"""Subscribe to `process/outputDelta` and `process/exited` for one process handle."""

def predicate(notification: Notification) -> bool:
if isinstance(notification, protocol.ProcessOutputDeltaNotificationModel):
return notification.params.processHandle == process_handle
if isinstance(notification, protocol.ProcessExitedNotificationModel):
return notification.params.processHandle == process_handle
return False

return self._session.subscribe_notifications(
{"process/outputDelta", "process/exited"},
predicate=predicate,
)


class AsyncAppServerClient:
"""Async client for `codex app-server`."""
Expand Down
25 changes: 25 additions & 0 deletions codex/app_server/_async_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
DEFAULT_REVIEW_DELIVERY = protocol.ReviewDelivery("inline")

_TURN_STREAM_NOTIFICATION_METHODS = {
"error",
"turn/started",
"turn/completed",
"turn/diff/updated",
Expand All @@ -54,6 +55,7 @@
"item/reasoning/summaryPartAdded",
"item/reasoning/textDelta",
"item/commandExecution/outputDelta",
"item/commandExecution/terminalInteraction",
"item/fileChange/outputDelta",
"serverRequest/resolved",
}
Expand Down Expand Up @@ -132,6 +134,7 @@ def __init__(
self.usage: protocol.ThreadTokenUsage | None = None
self._item_index: dict[str, int] = {}
self._text_deltas: list[str] = []
self._retryable_error_notifications: list[protocol.ErrorNotificationModel] = []
self._done = False
self._closed = False

Expand Down Expand Up @@ -188,6 +191,16 @@ async def __anext__(self) -> Notification:
raise StopAsyncIteration
notification = await self._subscription.next()
self._apply(notification)
if isinstance(notification, protocol.ErrorNotificationModel):
if not notification.params.willRetry:
self._done = True
await self.close()
error = notification.params.error
message = error.message
if error.additionalDetails is not None and error.additionalDetails != "":
message = f"{message}: {error.additionalDetails}"
raise AppServerTurnError(message)
self._retryable_error_notifications.append(notification)
if isinstance(notification, protocol.TurnCompletedNotificationModel):
self._done = True
return notification
Expand Down Expand Up @@ -274,6 +287,18 @@ def text_deltas(self) -> tuple[str, ...]:
"""Return the streamed agent text deltas received so far."""
return tuple(self._text_deltas)

@property
def retryable_error_notifications(self) -> tuple[protocol.ErrorNotificationModel, ...]:
"""Return retryable turn error notifications received so far."""
return tuple(self._retryable_error_notifications)

@property
def retryable_errors(self) -> tuple[protocol.TurnError, ...]:
"""Return retryable turn errors received so far."""
return tuple(
notification.params.error for notification in self._retryable_error_notifications
)

def _apply(self, notification: Notification) -> None:
self._apply_text_delta(notification)
self._apply_token_usage(notification)
Expand Down
18 changes: 18 additions & 0 deletions codex/app_server/_sync_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def subscribe(
methods: Collection[str] | None = None,
) -> _AsyncNotificationSubscription: ...

def subscribe_command_exec_output(self, process_id: str) -> _AsyncNotificationSubscription: ...

def subscribe_process_events(self, process_handle: str) -> _AsyncNotificationSubscription: ...


class _AsyncTurnStreamLike(Protocol):
initial_turn: protocol.Turn
Expand Down Expand Up @@ -157,6 +161,20 @@ def subscribe(self, methods: Collection[str] | None = None) -> NotificationSubsc
self._run,
)

def subscribe_command_exec_output(self, process_id: str) -> NotificationSubscription:
"""Subscribe to `command/exec/outputDelta` notifications for one process id."""
return NotificationSubscription(
self._async_events.subscribe_command_exec_output(process_id),
self._run,
)

def subscribe_process_events(self, process_handle: str) -> NotificationSubscription:
"""Subscribe to `process/outputDelta` and `process/exited` for one process handle."""
return NotificationSubscription(
self._async_events.subscribe_process_events(process_handle),
self._run,
)


class TurnStream(_SyncRunner):
"""Synchronous iterator over protocol-native notifications for a single turn."""
Expand Down
197 changes: 196 additions & 1 deletion tests/test_app_server_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from codex.app_server._async_client import AsyncEventsClient, AsyncTurnStream
from codex.app_server.errors import AppServerProtocolError
from codex.app_server.errors import AppServerProtocolError, AppServerTurnError
from codex.app_server.models import ReviewResult
from codex.protocol import types as protocol

Expand Down Expand Up @@ -34,6 +34,17 @@ def update_predicate(self, predicate: object) -> None:
self.updated_predicate = predicate


class _QueuedSubscription(_FakeSubscription):
def __init__(self, notifications: list[protocol.Notification]) -> None:
super().__init__()
self._notifications = notifications

async def next(self) -> protocol.Notification:
if not self._notifications:
raise StopAsyncIteration
return self._notifications.pop(0)


class _FakeThread:
def __init__(self) -> None:
self.id = "thr-1"
Expand Down Expand Up @@ -76,6 +87,94 @@ def test_async_events_client_subscribe_delegates_to_session() -> None:
assert session.calls == [(["turn/completed"], None)]


def test_async_events_client_subscribe_command_exec_output_filters_by_process_id() -> None:
session = _FakeSession()
events = AsyncEventsClient(session) # type: ignore[arg-type]

subscription = events.subscribe_command_exec_output("proc-1")

assert subscription == "subscription"
methods, predicate = session.calls[0]
assert methods == {"command/exec/outputDelta"}
assert callable(predicate)
matching = protocol.CommandExecOutputDeltaNotificationModel.model_validate(
{
"method": "command/exec/outputDelta",
"params": {
"capReached": False,
"deltaBase64": "aGVsbG8K",
"processId": "proc-1",
"stream": "stdout",
},
}
)
other_process = protocol.CommandExecOutputDeltaNotificationModel.model_validate(
{
"method": "command/exec/outputDelta",
"params": {
"capReached": False,
"deltaBase64": "aGVsbG8K",
"processId": "proc-2",
"stream": "stdout",
},
}
)

assert predicate(matching) is True
assert predicate(other_process) is False


def test_async_events_client_subscribe_process_events_filters_by_process_handle() -> None:
session = _FakeSession()
events = AsyncEventsClient(session) # type: ignore[arg-type]

subscription = events.subscribe_process_events("proc-handle-1")

assert subscription == "subscription"
methods, predicate = session.calls[0]
assert methods == {"process/outputDelta", "process/exited"}
assert callable(predicate)
output = protocol.ProcessOutputDeltaNotificationModel.model_validate(
{
"method": "process/outputDelta",
"params": {
"capReached": False,
"deltaBase64": "aGVsbG8K",
"processHandle": "proc-handle-1",
"stream": "stdout",
},
}
)
exited = protocol.ProcessExitedNotificationModel.model_validate(
{
"method": "process/exited",
"params": {
"exitCode": 0,
"processHandle": "proc-handle-1",
"stderr": "",
"stderrCapReached": False,
"stdout": "",
"stdoutCapReached": False,
},
}
)
other_handle = protocol.ProcessOutputDeltaNotificationModel.model_validate(
{
"method": "process/outputDelta",
"params": {
"capReached": False,
"deltaBase64": "aGVsbG8K",
"processHandle": "proc-handle-2",
"stream": "stdout",
},
}
)

assert predicate(output) is True
assert predicate(exited) is True
assert predicate(other_handle) is False


def test_async_turn_stream_scope_predicate_filters_by_thread_and_turn() -> None:
predicate = AsyncTurnStream._scope_predicate("thr-1", "turn-1")

Expand Down Expand Up @@ -339,6 +438,102 @@ async def scenario() -> None:
asyncio.run(scenario())


def test_async_turn_stream_raises_and_closes_on_non_retryable_error_notification() -> None:
error_notification = protocol.ErrorNotificationModel.model_validate(
{
"method": "error",
"params": {
"threadId": "thr-1",
"turnId": "turn-1",
"willRetry": False,
"error": {
"message": "model unavailable",
"additionalDetails": "try another model",
},
},
}
)

async def scenario() -> None:
subscription = _QueuedSubscription([error_notification])
stream = AsyncTurnStream(
_FakeThread(), # type: ignore[arg-type]
subscription, # type: ignore[arg-type]
protocol.Turn.model_validate(_turn_payload(status="inProgress")),
)

with pytest.raises(AppServerTurnError, match="model unavailable: try another model"):
await stream.__anext__()

assert subscription.closed is True

asyncio.run(scenario())


def test_async_turn_stream_yields_retryable_error_notification() -> None:
error_notification = protocol.ErrorNotificationModel.model_validate(
{
"method": "error",
"params": {
"threadId": "thr-1",
"turnId": "turn-1",
"willRetry": True,
"error": {"message": "temporary outage"},
},
}
)

async def scenario() -> None:
subscription = _QueuedSubscription([error_notification])
stream = AsyncTurnStream(
_FakeThread(), # type: ignore[arg-type]
subscription, # type: ignore[arg-type]
protocol.Turn.model_validate(_turn_payload(status="inProgress")),
)

assert await stream.__anext__() == error_notification
assert subscription.closed is False

asyncio.run(scenario())


def test_async_turn_stream_wait_preserves_retryable_error_notifications() -> None:
error_notification = protocol.ErrorNotificationModel.model_validate(
{
"method": "error",
"params": {
"threadId": "thr-1",
"turnId": "turn-1",
"willRetry": True,
"error": {
"message": "temporary outage",
"additionalDetails": "retrying with fallback",
},
},
}
)
turn_completed = protocol.TurnCompletedNotificationModel.model_validate(
{
"method": "turn/completed",
"params": {"threadId": "thr-1", "turn": _turn_payload(status="completed")},
}
)

async def scenario() -> None:
subscription = _QueuedSubscription([error_notification, turn_completed])
stream = AsyncTurnStream(
_FakeThread(), # type: ignore[arg-type]
subscription, # type: ignore[arg-type]
protocol.Turn.model_validate(_turn_payload(status="inProgress")),
)

assert await stream.wait() is stream
assert stream.retryable_error_notifications == (error_notification,)
assert stream.retryable_errors == (error_notification.params.error,)

asyncio.run(scenario())


def test_async_turn_stream_raise_for_terminal_status_requires_completion() -> None:
stream = AsyncTurnStream(
_FakeThread(), # type: ignore[arg-type]
Expand Down
12 changes: 5 additions & 7 deletions tests/test_stream_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,19 @@ def test_consume_two_turns_same_thread() -> None:
client.close()


def test_turn_completed_not_in_notification_methods_does_not_hang() -> None:
def test_turn_stream_yields_terminal_interaction_and_does_not_hang() -> None:
"""Reproduce review-action handling when terminalInteraction is emitted.

codex-review-action handles `item/commandExecution/terminalInteraction`, but
`_TURN_STREAM_NOTIFICATION_METHODS` does not subscribe to that method.
This test verifies such events are effectively dropped for the turn stream and
do not cause `wait()` to hang.
codex-review-action handles `item/commandExecution/terminalInteraction`;
the turn stream should now deliver it as a typed protocol notification and
still allow a later `wait()` call to finish.
"""

client, transport = _make_sync_client()
try:
thread = client.start_thread()
stream = thread.run("Terminal interaction compatibility")

# This method is intentionally *not* part of turn stream subscribed methods.
transport.push(
{
"method": "item/commandExecution/terminalInteraction",
Expand Down Expand Up @@ -374,8 +372,8 @@ def test_turn_completed_not_in_notification_methods_does_not_hang() -> None:
task_complete, events = _consume_like_review_action(stream)

assert task_complete is True
# If terminalInteraction had been routed to stream, we'd see an extra event.
assert [type(event) for event in events] == [
protocol.ItemCommandExecutionTerminalInteractionNotification,
protocol.ItemCompletedNotificationModel,
protocol.TurnCompletedNotificationModel,
]
Expand Down
Loading