Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 15 additions & 82 deletions src/agents/extensions/sandbox/cloudflare/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from ....sandbox.session.base_sandbox_session import BaseSandboxSession
from ....sandbox.session.dependencies import Dependencies
from ....sandbox.session.manager import Instrumentation
from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed
from ....sandbox.session.pty_types import (
PTY_PROCESSES_MAX,
PTY_PROCESSES_WARNING,
Expand Down Expand Up @@ -1204,91 +1205,23 @@ async def _hydrate_workspace_via_http(self, data: io.IOBase) -> None:

async def persist_workspace(self) -> io.IOBase:
root = self._workspace_root_path()
unmounted_mounts: list[tuple[Any, Path]] = []
unmount_error: WorkspaceArchiveReadError | None = None
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
try:
await mount_entry.mount_strategy.teardown_for_snapshot(
mount_entry, self, mount_path
)
except Exception as e:
unmount_error = WorkspaceArchiveReadError(path=root, cause=e)
break
unmounted_mounts.append((mount_entry, mount_path))

snapshot_error: WorkspaceArchiveReadError | None = None
persisted: io.IOBase | None = None
if unmount_error is None:
try:
persisted = await self._persist_workspace_via_http()
except WorkspaceArchiveReadError as e:
snapshot_error = e

remount_error: WorkspaceArchiveReadError | None = None
for mount_entry, mount_path in reversed(unmounted_mounts):
try:
await mount_entry.mount_strategy.restore_after_snapshot(
mount_entry, self, mount_path
)
except Exception as e:
if remount_error is None:
remount_error = WorkspaceArchiveReadError(path=root, cause=e)

if remount_error is not None:
if snapshot_error is not None:
remount_error.context["snapshot_error_before_remount_corruption"] = {
"message": snapshot_error.message,
}
raise remount_error
if unmount_error is not None:
raise unmount_error
if snapshot_error is not None:
raise snapshot_error

assert persisted is not None
return persisted
return await with_ephemeral_mounts_removed(
self,
self._persist_workspace_via_http,
error_path=root,
error_cls=WorkspaceArchiveReadError,
operation_error_context_key="snapshot_error_before_remount_corruption",
)

async def hydrate_workspace(self, data: io.IOBase) -> None:
root = self._workspace_root_path()
unmounted_mounts: list[tuple[Any, Path]] = []
unmount_error: WorkspaceArchiveWriteError | None = None
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
try:
await mount_entry.mount_strategy.teardown_for_snapshot(
mount_entry, self, mount_path
)
except Exception as e:
unmount_error = WorkspaceArchiveWriteError(path=root, cause=e)
break
unmounted_mounts.append((mount_entry, mount_path))

hydrate_error: WorkspaceArchiveWriteError | None = None
if unmount_error is None:
try:
await self._hydrate_workspace_via_http(data)
except WorkspaceArchiveWriteError as e:
hydrate_error = e

remount_error: WorkspaceArchiveWriteError | None = None
for mount_entry, mount_path in reversed(unmounted_mounts):
try:
await mount_entry.mount_strategy.restore_after_snapshot(
mount_entry, self, mount_path
)
except Exception as e:
if remount_error is None:
remount_error = WorkspaceArchiveWriteError(path=root, cause=e)

if remount_error is not None:
if hydrate_error is not None:
remount_error.context["hydrate_error_before_remount_corruption"] = {
"message": hydrate_error.message,
}
raise remount_error
if unmount_error is not None:
raise unmount_error
if hydrate_error is not None:
raise hydrate_error
await with_ephemeral_mounts_removed(
self,
lambda: self._hydrate_workspace_via_http(data),
error_path=root,
error_cls=WorkspaceArchiveWriteError,
operation_error_context_key="hydrate_error_before_remount_corruption",
)


class CloudflareSandboxClient(BaseSandboxClient[CloudflareSandboxClientOptions]):
Expand Down
112 changes: 14 additions & 98 deletions src/agents/extensions/sandbox/vercel/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import posixpath
import tarfile
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path, PurePosixPath
from typing import Any, Literal, cast
from urllib.parse import urlsplit
Expand Down Expand Up @@ -50,6 +49,7 @@
from ....sandbox.session.base_sandbox_session import BaseSandboxSession
from ....sandbox.session.dependencies import Dependencies
from ....sandbox.session.manager import Instrumentation
from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed
from ....sandbox.session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript
from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions
from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot
Expand Down Expand Up @@ -404,100 +404,6 @@ async def running(self) -> bool:
async def shutdown(self) -> None:
await self._stop_attached_sandbox()

async def _persist_with_ephemeral_mounts_removed(
self,
operation: Callable[[], Awaitable[io.IOBase]],
) -> io.IOBase:
root = self._workspace_root_path()
unmounted_mounts: list[tuple[Any, Path]] = []
unmount_error: WorkspaceArchiveReadError | None = None
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
try:
await mount_entry.mount_strategy.teardown_for_snapshot(
mount_entry, self, mount_path
)
except Exception as exc:
unmount_error = WorkspaceArchiveReadError(path=root, cause=exc)
break
unmounted_mounts.append((mount_entry, mount_path))

persist_error: WorkspaceArchiveReadError | None = None
persisted: io.IOBase | None = None
if unmount_error is None:
try:
persisted = await operation()
except WorkspaceArchiveReadError as exc:
persist_error = exc

remount_error: WorkspaceArchiveReadError | None = None
for mount_entry, mount_path in reversed(unmounted_mounts):
try:
await mount_entry.mount_strategy.restore_after_snapshot(
mount_entry, self, mount_path
)
except Exception as exc:
if remount_error is None:
remount_error = WorkspaceArchiveReadError(path=root, cause=exc)

if remount_error is not None:
if persist_error is not None:
remount_error.context["snapshot_error_before_remount_corruption"] = {
"message": persist_error.message
}
raise remount_error
if unmount_error is not None:
raise unmount_error
if persist_error is not None:
raise persist_error

assert persisted is not None
return persisted

async def _hydrate_with_ephemeral_mounts_removed(
self,
operation: Callable[[], Awaitable[None]],
) -> None:
root = self._workspace_root_path()
unmounted_mounts: list[tuple[Any, Path]] = []
unmount_error: WorkspaceArchiveWriteError | None = None
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
try:
await mount_entry.mount_strategy.teardown_for_snapshot(
mount_entry, self, mount_path
)
except Exception as exc:
unmount_error = WorkspaceArchiveWriteError(path=root, cause=exc)
break
unmounted_mounts.append((mount_entry, mount_path))

hydrate_error: WorkspaceArchiveWriteError | None = None
if unmount_error is None:
try:
await operation()
except WorkspaceArchiveWriteError as exc:
hydrate_error = exc

remount_error: WorkspaceArchiveWriteError | None = None
for mount_entry, mount_path in reversed(unmounted_mounts):
try:
await mount_entry.mount_strategy.restore_after_snapshot(
mount_entry, self, mount_path
)
except Exception as exc:
if remount_error is None:
remount_error = WorkspaceArchiveWriteError(path=root, cause=exc)

if remount_error is not None:
if hydrate_error is not None:
remount_error.context["hydrate_error_before_remount_corruption"] = {
"message": hydrate_error.message
}
raise remount_error
if unmount_error is not None:
raise unmount_error
if hydrate_error is not None:
raise hydrate_error

async def _exec_internal(
self,
*command: str | Path,
Expand Down Expand Up @@ -601,7 +507,13 @@ async def write(
raise WorkspaceArchiveWriteError(path=normalized_path, cause=exc) from exc

async def persist_workspace(self) -> io.IOBase:
return await self._persist_with_ephemeral_mounts_removed(self._persist_workspace_internal)
return await with_ephemeral_mounts_removed(
self,
self._persist_workspace_internal,
error_path=self._workspace_root_path(),
error_cls=WorkspaceArchiveReadError,
operation_error_context_key="snapshot_error_before_remount_corruption",
)

async def _persist_workspace_internal(self) -> io.IOBase:
if self.state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT:
Expand Down Expand Up @@ -665,8 +577,12 @@ async def hydrate_workspace(self, data: io.IOBase) -> None:
actual_type=type(raw).__name__,
)

await self._hydrate_with_ephemeral_mounts_removed(
lambda: self._hydrate_workspace_internal(bytes(raw))
await with_ephemeral_mounts_removed(
self,
lambda: self._hydrate_workspace_internal(bytes(raw)),
error_path=self._workspace_root_path(),
error_cls=WorkspaceArchiveWriteError,
operation_error_context_key="hydrate_error_before_remount_corruption",
)

async def _hydrate_workspace_internal(self, raw: bytes) -> None:
Expand Down
112 changes: 112 additions & 0 deletions src/agents/sandbox/session/mount_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast

from ..errors import (
WorkspaceArchiveReadError,
WorkspaceArchiveWriteError,
WorkspaceIOError,
)

if TYPE_CHECKING:
from ..entries import Mount
from .base_sandbox_session import BaseSandboxSession

ArchiveError: TypeAlias = WorkspaceArchiveReadError | WorkspaceArchiveWriteError
ArchiveErrorClass: TypeAlias = type[WorkspaceArchiveReadError] | type[WorkspaceArchiveWriteError]

_ResultT = TypeVar("_ResultT")
_MISSING = object()


async def with_ephemeral_mounts_removed(
session: BaseSandboxSession,
operation: Callable[[], Awaitable[_ResultT]],
*,
error_path: Path,
error_cls: ArchiveErrorClass,
operation_error_context_key: str | None,
) -> _ResultT:
detached_mounts: list[tuple[Mount, Path]] = []
detach_error: ArchiveError | None = None
for mount_entry, mount_path in session.state.manifest.ephemeral_mount_targets():
try:
await mount_entry.mount_strategy.teardown_for_snapshot(mount_entry, session, mount_path)
except Exception as exc:
detach_error = error_cls(path=error_path, cause=exc)
break
detached_mounts.append((mount_entry, mount_path))

operation_error: ArchiveError | None = None
operation_result: object = _MISSING
if detach_error is None:
try:
operation_result = await operation()
except WorkspaceIOError as exc:
if not isinstance(exc, error_cls):
raise
operation_error = cast(ArchiveError, exc)

restore_error = await restore_detached_mounts(
session,
detached_mounts,
error_path=error_path,
error_cls=error_cls,
)

if restore_error is not None:
if operation_error is not None and operation_error_context_key is not None:
restore_error.context[operation_error_context_key] = {
"message": operation_error.message
}
raise restore_error
if detach_error is not None:
raise detach_error
if operation_error is not None:
raise operation_error

assert operation_result is not _MISSING
return cast(_ResultT, operation_result)


async def restore_detached_mounts(
session: BaseSandboxSession,
detached_mounts: list[tuple[Mount, Path]],
*,
error_path: Path,
error_cls: ArchiveErrorClass,
) -> ArchiveError | None:
restore_error: ArchiveError | None = None
for mount_entry, mount_path in reversed(detached_mounts):
try:
await mount_entry.mount_strategy.restore_after_snapshot(
mount_entry, session, mount_path
)
except Exception as exc:
current_error = error_cls(path=error_path, cause=exc)
if restore_error is None:
restore_error = current_error
else:
additional_errors = restore_error.context.setdefault(
"additional_remount_errors", []
)
assert isinstance(additional_errors, list)
additional_errors.append(workspace_archive_error_summary(current_error))
return restore_error


def workspace_archive_error_summary(error: ArchiveError) -> dict[str, str]:
summary = {"message": error.message}
if error.cause is not None:
summary["cause_type"] = type(error.cause).__name__
summary["cause"] = str(error.cause)
return summary


__all__ = [
"restore_detached_mounts",
"with_ephemeral_mounts_removed",
"workspace_archive_error_summary",
]
Loading
Loading