From afda03ef40c4bc69570077e7c9ce3d3ff2abe65a Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 21 Apr 2026 12:48:30 +0900 Subject: [PATCH 1/2] refactor: share sandbox ephemeral mount lifecycle --- .../extensions/sandbox/cloudflare/sandbox.py | 97 ++----------- .../extensions/sandbox/vercel/sandbox.py | 112 ++------------- src/agents/sandbox/session/mount_lifecycle.py | 112 +++++++++++++++ tests/sandbox/test_mount_lifecycle.py | 134 ++++++++++++++++++ 4 files changed, 275 insertions(+), 180 deletions(-) create mode 100644 src/agents/sandbox/session/mount_lifecycle.py create mode 100644 tests/sandbox/test_mount_lifecycle.py diff --git a/src/agents/extensions/sandbox/cloudflare/sandbox.py b/src/agents/extensions/sandbox/cloudflare/sandbox.py index 281f6b4545..0454323ea2 100644 --- a/src/agents/extensions/sandbox/cloudflare/sandbox.py +++ b/src/agents/extensions/sandbox/cloudflare/sandbox.py @@ -47,6 +47,7 @@ from ....sandbox.session.base_sandbox_session import BaseSandboxSession from ....sandbox.session.dependencies import Dependencies from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed from ....sandbox.session.pty_types import ( PTY_PROCESSES_MAX, PTY_PROCESSES_WARNING, @@ -1204,91 +1205,23 @@ async def _hydrate_workspace_via_http(self, data: io.IOBase) -> None: async def persist_workspace(self) -> io.IOBase: root = self._workspace_root_path() - unmounted_mounts: list[tuple[Any, Path]] = [] - unmount_error: WorkspaceArchiveReadError | None = None - for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets(): - try: - await mount_entry.mount_strategy.teardown_for_snapshot( - mount_entry, self, mount_path - ) - except Exception as e: - unmount_error = WorkspaceArchiveReadError(path=root, cause=e) - break - unmounted_mounts.append((mount_entry, mount_path)) - - snapshot_error: WorkspaceArchiveReadError | None = None - persisted: io.IOBase | None = None - if unmount_error is None: - try: - persisted = await self._persist_workspace_via_http() - except WorkspaceArchiveReadError as e: - snapshot_error = e - - remount_error: WorkspaceArchiveReadError | None = None - for mount_entry, mount_path in reversed(unmounted_mounts): - try: - await mount_entry.mount_strategy.restore_after_snapshot( - mount_entry, self, mount_path - ) - except Exception as e: - if remount_error is None: - remount_error = WorkspaceArchiveReadError(path=root, cause=e) - - if remount_error is not None: - if snapshot_error is not None: - remount_error.context["snapshot_error_before_remount_corruption"] = { - "message": snapshot_error.message, - } - raise remount_error - if unmount_error is not None: - raise unmount_error - if snapshot_error is not None: - raise snapshot_error - - assert persisted is not None - return persisted + return await with_ephemeral_mounts_removed( + self, + self._persist_workspace_via_http, + error_path=root, + error_cls=WorkspaceArchiveReadError, + operation_error_context_key="snapshot_error_before_remount_corruption", + ) async def hydrate_workspace(self, data: io.IOBase) -> None: root = self._workspace_root_path() - unmounted_mounts: list[tuple[Any, Path]] = [] - unmount_error: WorkspaceArchiveWriteError | None = None - for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets(): - try: - await mount_entry.mount_strategy.teardown_for_snapshot( - mount_entry, self, mount_path - ) - except Exception as e: - unmount_error = WorkspaceArchiveWriteError(path=root, cause=e) - break - unmounted_mounts.append((mount_entry, mount_path)) - - hydrate_error: WorkspaceArchiveWriteError | None = None - if unmount_error is None: - try: - await self._hydrate_workspace_via_http(data) - except WorkspaceArchiveWriteError as e: - hydrate_error = e - - remount_error: WorkspaceArchiveWriteError | None = None - for mount_entry, mount_path in reversed(unmounted_mounts): - try: - await mount_entry.mount_strategy.restore_after_snapshot( - mount_entry, self, mount_path - ) - except Exception as e: - if remount_error is None: - remount_error = WorkspaceArchiveWriteError(path=root, cause=e) - - if remount_error is not None: - if hydrate_error is not None: - remount_error.context["hydrate_error_before_remount_corruption"] = { - "message": hydrate_error.message, - } - raise remount_error - if unmount_error is not None: - raise unmount_error - if hydrate_error is not None: - raise hydrate_error + await with_ephemeral_mounts_removed( + self, + lambda: self._hydrate_workspace_via_http(data), + error_path=root, + error_cls=WorkspaceArchiveWriteError, + operation_error_context_key="hydrate_error_before_remount_corruption", + ) class CloudflareSandboxClient(BaseSandboxClient[CloudflareSandboxClientOptions]): diff --git a/src/agents/extensions/sandbox/vercel/sandbox.py b/src/agents/extensions/sandbox/vercel/sandbox.py index 8deafc2b2a..c0041bd79b 100644 --- a/src/agents/extensions/sandbox/vercel/sandbox.py +++ b/src/agents/extensions/sandbox/vercel/sandbox.py @@ -17,7 +17,6 @@ import posixpath import tarfile import uuid -from collections.abc import Awaitable, Callable from pathlib import Path, PurePosixPath from typing import Any, Literal, cast from urllib.parse import urlsplit @@ -50,6 +49,7 @@ from ....sandbox.session.base_sandbox_session import BaseSandboxSession from ....sandbox.session.dependencies import Dependencies from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed from ....sandbox.session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot @@ -404,100 +404,6 @@ async def running(self) -> bool: async def shutdown(self) -> None: await self._stop_attached_sandbox() - async def _persist_with_ephemeral_mounts_removed( - self, - operation: Callable[[], Awaitable[io.IOBase]], - ) -> io.IOBase: - root = self._workspace_root_path() - unmounted_mounts: list[tuple[Any, Path]] = [] - unmount_error: WorkspaceArchiveReadError | None = None - for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets(): - try: - await mount_entry.mount_strategy.teardown_for_snapshot( - mount_entry, self, mount_path - ) - except Exception as exc: - unmount_error = WorkspaceArchiveReadError(path=root, cause=exc) - break - unmounted_mounts.append((mount_entry, mount_path)) - - persist_error: WorkspaceArchiveReadError | None = None - persisted: io.IOBase | None = None - if unmount_error is None: - try: - persisted = await operation() - except WorkspaceArchiveReadError as exc: - persist_error = exc - - remount_error: WorkspaceArchiveReadError | None = None - for mount_entry, mount_path in reversed(unmounted_mounts): - try: - await mount_entry.mount_strategy.restore_after_snapshot( - mount_entry, self, mount_path - ) - except Exception as exc: - if remount_error is None: - remount_error = WorkspaceArchiveReadError(path=root, cause=exc) - - if remount_error is not None: - if persist_error is not None: - remount_error.context["snapshot_error_before_remount_corruption"] = { - "message": persist_error.message - } - raise remount_error - if unmount_error is not None: - raise unmount_error - if persist_error is not None: - raise persist_error - - assert persisted is not None - return persisted - - async def _hydrate_with_ephemeral_mounts_removed( - self, - operation: Callable[[], Awaitable[None]], - ) -> None: - root = self._workspace_root_path() - unmounted_mounts: list[tuple[Any, Path]] = [] - unmount_error: WorkspaceArchiveWriteError | None = None - for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets(): - try: - await mount_entry.mount_strategy.teardown_for_snapshot( - mount_entry, self, mount_path - ) - except Exception as exc: - unmount_error = WorkspaceArchiveWriteError(path=root, cause=exc) - break - unmounted_mounts.append((mount_entry, mount_path)) - - hydrate_error: WorkspaceArchiveWriteError | None = None - if unmount_error is None: - try: - await operation() - except WorkspaceArchiveWriteError as exc: - hydrate_error = exc - - remount_error: WorkspaceArchiveWriteError | None = None - for mount_entry, mount_path in reversed(unmounted_mounts): - try: - await mount_entry.mount_strategy.restore_after_snapshot( - mount_entry, self, mount_path - ) - except Exception as exc: - if remount_error is None: - remount_error = WorkspaceArchiveWriteError(path=root, cause=exc) - - if remount_error is not None: - if hydrate_error is not None: - remount_error.context["hydrate_error_before_remount_corruption"] = { - "message": hydrate_error.message - } - raise remount_error - if unmount_error is not None: - raise unmount_error - if hydrate_error is not None: - raise hydrate_error - async def _exec_internal( self, *command: str | Path, @@ -601,7 +507,13 @@ async def write( raise WorkspaceArchiveWriteError(path=normalized_path, cause=exc) from exc async def persist_workspace(self) -> io.IOBase: - return await self._persist_with_ephemeral_mounts_removed(self._persist_workspace_internal) + return await with_ephemeral_mounts_removed( + self, + self._persist_workspace_internal, + error_path=self._workspace_root_path(), + error_cls=WorkspaceArchiveReadError, + operation_error_context_key="snapshot_error_before_remount_corruption", + ) async def _persist_workspace_internal(self) -> io.IOBase: if self.state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT: @@ -665,8 +577,12 @@ async def hydrate_workspace(self, data: io.IOBase) -> None: actual_type=type(raw).__name__, ) - await self._hydrate_with_ephemeral_mounts_removed( - lambda: self._hydrate_workspace_internal(bytes(raw)) + await with_ephemeral_mounts_removed( + self, + lambda: self._hydrate_workspace_internal(bytes(raw)), + error_path=self._workspace_root_path(), + error_cls=WorkspaceArchiveWriteError, + operation_error_context_key="hydrate_error_before_remount_corruption", ) async def _hydrate_workspace_internal(self, raw: bytes) -> None: diff --git a/src/agents/sandbox/session/mount_lifecycle.py b/src/agents/sandbox/session/mount_lifecycle.py new file mode 100644 index 0000000000..bf32d82a17 --- /dev/null +++ b/src/agents/sandbox/session/mount_lifecycle.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast + +from ..errors import ( + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceIOError, +) + +if TYPE_CHECKING: + from ..entries import Mount + from .base_sandbox_session import BaseSandboxSession + +ArchiveError: TypeAlias = WorkspaceArchiveReadError | WorkspaceArchiveWriteError +ArchiveErrorClass: TypeAlias = type[WorkspaceArchiveReadError] | type[WorkspaceArchiveWriteError] + +_ResultT = TypeVar("_ResultT") +_MISSING = object() + + +async def with_ephemeral_mounts_removed( + session: BaseSandboxSession, + operation: Callable[[], Awaitable[_ResultT]], + *, + error_path: Path, + error_cls: ArchiveErrorClass, + operation_error_context_key: str | None, +) -> _ResultT: + detached_mounts: list[tuple[Mount, Path]] = [] + detach_error: ArchiveError | None = None + for mount_entry, mount_path in session.state.manifest.ephemeral_mount_targets(): + try: + await mount_entry.mount_strategy.teardown_for_snapshot(mount_entry, session, mount_path) + except Exception as exc: + detach_error = error_cls(path=error_path, cause=exc) + break + detached_mounts.append((mount_entry, mount_path)) + + operation_error: ArchiveError | None = None + operation_result: object = _MISSING + if detach_error is None: + try: + operation_result = await operation() + except WorkspaceIOError as exc: + if not isinstance(exc, error_cls): + raise + operation_error = cast(ArchiveError, exc) + + restore_error = await restore_detached_mounts( + session, + detached_mounts, + error_path=error_path, + error_cls=error_cls, + ) + + if restore_error is not None: + if operation_error is not None and operation_error_context_key is not None: + restore_error.context[operation_error_context_key] = { + "message": operation_error.message + } + raise restore_error + if detach_error is not None: + raise detach_error + if operation_error is not None: + raise operation_error + + assert operation_result is not _MISSING + return cast(_ResultT, operation_result) + + +async def restore_detached_mounts( + session: BaseSandboxSession, + detached_mounts: list[tuple[Mount, Path]], + *, + error_path: Path, + error_cls: ArchiveErrorClass, +) -> ArchiveError | None: + restore_error: ArchiveError | None = None + for mount_entry, mount_path in reversed(detached_mounts): + try: + await mount_entry.mount_strategy.restore_after_snapshot( + mount_entry, session, mount_path + ) + except Exception as exc: + current_error = error_cls(path=error_path, cause=exc) + if restore_error is None: + restore_error = current_error + else: + additional_errors = restore_error.context.setdefault( + "additional_remount_errors", [] + ) + assert isinstance(additional_errors, list) + additional_errors.append(workspace_archive_error_summary(current_error)) + return restore_error + + +def workspace_archive_error_summary(error: ArchiveError) -> dict[str, str]: + summary = {"message": error.message} + if error.cause is not None: + summary["cause_type"] = type(error.cause).__name__ + summary["cause"] = str(error.cause) + return summary + + +__all__ = [ + "restore_detached_mounts", + "with_ephemeral_mounts_removed", + "workspace_archive_error_summary", +] diff --git a/tests/sandbox/test_mount_lifecycle.py b/tests/sandbox/test_mount_lifecycle.py new file mode 100644 index 0000000000..917a7204ee --- /dev/null +++ b/tests/sandbox/test_mount_lifecycle.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, cast + +import pytest + +from agents.sandbox.errors import WorkspaceArchiveReadError +from agents.sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed + + +class _FakeMountStrategy: + def __init__( + self, + events: list[str], + *, + name: str, + fail_teardown: bool = False, + fail_restore: bool = False, + ) -> None: + self._events = events + self._name = name + self._fail_teardown = fail_teardown + self._fail_restore = fail_restore + + async def teardown_for_snapshot( + self, + mount: object, + session: object, + path: Path, + ) -> None: + _ = (mount, session, path) + self._events.append(f"teardown:{self._name}") + if self._fail_teardown: + raise RuntimeError(f"teardown failed: {self._name}") + + async def restore_after_snapshot( + self, + mount: object, + session: object, + path: Path, + ) -> None: + _ = (mount, session, path) + self._events.append(f"restore:{self._name}") + if self._fail_restore: + raise RuntimeError(f"restore failed: {self._name}") + + +class _FakeMount: + def __init__(self, strategy: _FakeMountStrategy) -> None: + self.mount_strategy = strategy + + +class _FakeManifest: + def __init__(self, mounts: list[tuple[_FakeMount, Path]]) -> None: + self._mounts = mounts + + def ephemeral_mount_targets(self) -> list[tuple[_FakeMount, Path]]: + return self._mounts + + +class _FakeState: + def __init__(self, manifest: _FakeManifest) -> None: + self.manifest = manifest + + +class _FakeSession: + def __init__(self, manifest: _FakeManifest) -> None: + self.state = _FakeState(manifest) + + +@pytest.mark.asyncio +async def test_with_ephemeral_mounts_removed_restores_in_reverse_order() -> None: + events: list[str] = [] + left = _FakeMount(_FakeMountStrategy(events, name="left")) + right = _FakeMount(_FakeMountStrategy(events, name="right")) + session = _FakeSession( + _FakeManifest( + [ + (left, Path("/workspace/left")), + (right, Path("/workspace/right")), + ] + ) + ) + + async def operation() -> str: + events.append("operation") + return "persisted" + + result = await with_ephemeral_mounts_removed( + cast(Any, session), + operation, + error_path=Path("/workspace"), + error_cls=WorkspaceArchiveReadError, + operation_error_context_key="snapshot_error_before_remount_corruption", + ) + + assert result == "persisted" + assert events == [ + "teardown:left", + "teardown:right", + "operation", + "restore:right", + "restore:left", + ] + + +@pytest.mark.asyncio +async def test_with_ephemeral_mounts_removed_reports_restore_error_after_operation_error() -> None: + events: list[str] = [] + mount = _FakeMount(_FakeMountStrategy(events, name="mount", fail_restore=True)) + session = _FakeSession(_FakeManifest([(mount, Path("/workspace/mount"))])) + + async def operation() -> bytes: + events.append("operation") + raise WorkspaceArchiveReadError( + path=Path("/workspace"), + context={"reason": "persist_failed"}, + ) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await with_ephemeral_mounts_removed( + cast(Any, session), + operation, + error_path=Path("/workspace"), + error_cls=WorkspaceArchiveReadError, + operation_error_context_key="snapshot_error_before_remount_corruption", + ) + + assert events == ["teardown:mount", "operation", "restore:mount"] + assert exc_info.value.context["snapshot_error_before_remount_corruption"] == { + "message": "failed to read archive for path: /workspace", + } + assert isinstance(exc_info.value.cause, RuntimeError) From 3246243e8a994f2a7a52ee4788da743d627c6620 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 21 Apr 2026 12:58:49 +0900 Subject: [PATCH 2/2] fix tests --- tests/sandbox/test_mount_lifecycle.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/sandbox/test_mount_lifecycle.py b/tests/sandbox/test_mount_lifecycle.py index 917a7204ee..4fea072847 100644 --- a/tests/sandbox/test_mount_lifecycle.py +++ b/tests/sandbox/test_mount_lifecycle.py @@ -110,13 +110,14 @@ async def test_with_ephemeral_mounts_removed_reports_restore_error_after_operati events: list[str] = [] mount = _FakeMount(_FakeMountStrategy(events, name="mount", fail_restore=True)) session = _FakeSession(_FakeManifest([(mount, Path("/workspace/mount"))])) + operation_error = WorkspaceArchiveReadError( + path=Path("/workspace"), + context={"reason": "persist_failed"}, + ) async def operation() -> bytes: events.append("operation") - raise WorkspaceArchiveReadError( - path=Path("/workspace"), - context={"reason": "persist_failed"}, - ) + raise operation_error with pytest.raises(WorkspaceArchiveReadError) as exc_info: await with_ephemeral_mounts_removed( @@ -129,6 +130,6 @@ async def operation() -> bytes: assert events == ["teardown:mount", "operation", "restore:mount"] assert exc_info.value.context["snapshot_error_before_remount_corruption"] == { - "message": "failed to read archive for path: /workspace", + "message": operation_error.message, } assert isinstance(exc_info.value.cause, RuntimeError)