From c8f96813a8451bec709fb87b95396c9641beb523 Mon Sep 17 00:00:00 2001 From: tnm Date: Tue, 17 Feb 2026 12:02:44 -0800 Subject: [PATCH 1/5] Harden providers, implement eager pooling, and verify latest SDKs --- pyproject.toml | 24 +++--- sandboxes/cli.py | 4 +- sandboxes/pool.py | 57 ++++++++++++-- sandboxes/providers/cloudflare.py | 40 ++++++++-- sandboxes/providers/daytona.py | 14 +++- sandboxes/providers/e2b.py | 59 +++++++++++--- tests/test_cli.py | 4 +- tests/test_cloudflare_provider.py | 90 ++++++++++++++++++++++ tests/test_daytona_provider_regressions.py | 49 ++++++++++++ tests/test_e2b_provider_regressions.py | 83 ++++++++++++++++++++ tests/test_pool.py | 45 +++++++++++ 11 files changed, 430 insertions(+), 39 deletions(-) create mode 100644 tests/test_daytona_provider_regressions.py create mode 100644 tests/test_e2b_provider_regressions.py diff --git a/pyproject.toml b/pyproject.toml index 5a353b9..6aab08a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,25 +26,25 @@ dependencies = [ "typing-extensions>=4.0.0", "click>=8.0.0", "tabulate>=0.9.0", - "modal>=1.1.4", - "e2b>=2.0.0", - "daytona>=0.103.0", - "hopx-ai>=0.3.0", + "modal>=1.3.3", + "e2b>=2.13.2", + "daytona>=0.143.0", + "hopx-ai>=0.5.0", "httpx>=0.27.0", ] [project.optional-dependencies] daytona = [ - "daytona==0.103.0", # Official Daytona SDK - latest stable version + "daytona==0.143.0", # Official Daytona SDK - latest stable version ] e2b = [ - "e2b>=2.0.0", # Regular E2B SDK for standard Linux sandboxes + "e2b>=2.13.2", # Regular E2B SDK for standard Linux sandboxes ] modal = [ - "modal==1.1.4", # Latest stable version + "modal==1.3.3", # Latest stable version ] hopx = [ - "hopx-ai>=0.3.0", # Official Hopx SDK for secure cloud sandboxes + "hopx-ai>=0.5.0", # Official Hopx SDK for secure cloud sandboxes ] # vercel = [ # "vercel-sdk>=0.1.0", # When available @@ -53,10 +53,10 @@ hopx = [ # "cloudflare-workers-sdk>=0.1.0", # When available # ] all = [ - "daytona==0.103.0", - "e2b>=2.0.0", - "modal==1.1.4", - "hopx-ai>=0.3.0", + "daytona==0.143.0", + "e2b>=2.13.2", + "modal==1.3.3", + "hopx-ai>=0.5.0", ] dev = [ "pytest>=7.4.0", diff --git a/sandboxes/cli.py b/sandboxes/cli.py index d1de350..0d8d5b0 100644 --- a/sandboxes/cli.py +++ b/sandboxes/cli.py @@ -8,6 +8,8 @@ import click +from . import __version__ + def get_provider(name: str): """Get a provider instance by name.""" @@ -38,7 +40,7 @@ def get_provider(name: str): @click.group() -@click.version_option(version="0.2.3", prog_name="cased-sandboxes") +@click.version_option(version=__version__, prog_name="cased-sandboxes") def cli(): """Universal AI code execution sandboxes.""" pass diff --git a/sandboxes/pool.py b/sandboxes/pool.py index 625572a..e92e36a 100644 --- a/sandboxes/pool.py +++ b/sandboxes/pool.py @@ -93,6 +93,11 @@ def __init__(self, pool_config: PoolConfig | None = None): # Locks for thread-safe operations self._lock = asyncio.Lock() self._condition = asyncio.Condition(self._lock) + self._ensure_min_idle_lock = asyncio.Lock() + + # Template used for eager prewarming + self._warm_provider: Any | None = None + self._warm_config: SandboxConfig | None = None # Cleanup task self._cleanup_task: asyncio.Task | None = None @@ -106,11 +111,15 @@ def __init__(self, pool_config: PoolConfig | None = None): "errors": 0, } - async def start(self): + async def start(self, provider: Any | None = None, config: SandboxConfig | None = None): """Start the pool and background tasks.""" if self.config.auto_cleanup: self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + if provider and config: + self._warm_provider = provider + self._warm_config = config + # Pre-create sandboxes if using eager strategy if self.config.strategy == PoolStrategy.EAGER: await self._ensure_min_idle() @@ -149,6 +158,12 @@ async def acquire( if self.config.max_total <= 0: raise SandboxQuotaError("Pool limit reached: 0") + self._warm_provider = provider + self._warm_config = config + + if self.config.strategy == PoolStrategy.EAGER: + await self._ensure_min_idle(provider, config) + eviction_entry: SandboxPoolEntry | None = None try: @@ -223,6 +238,9 @@ async def release(self, sandbox_id: str): for entry in evictions: await self._finalize_eviction(entry) + if self.config.strategy == PoolStrategy.EAGER: + await self._ensure_min_idle() + async def destroy(self, sandbox_id: str): """ Destroy a sandbox and remove from pool. @@ -266,6 +284,7 @@ async def _create_sandbox(self, provider: Any, config: SandboxConfig) -> Sandbox # Add to pool self._pool[sandbox.id] = entry + self._idle_sandboxes.add(sandbox.id) # Update label index for key, value in entry.labels.items(): @@ -354,11 +373,36 @@ async def _evict_idle_sandbox(self) -> bool: await self._finalize_eviction(entry) return True - async def _ensure_min_idle(self): + async def _ensure_min_idle( + self, provider: Any | None = None, config: SandboxConfig | None = None + ) -> None: """Ensure minimum idle sandboxes (for eager strategy).""" - # This would need provider and config information - # Implement based on specific requirements - pass + async with self._ensure_min_idle_lock: + if provider and config: + self._warm_provider = provider + self._warm_config = config + + provider_to_use = provider or self._warm_provider + config_to_use = config or self._warm_config + if provider_to_use is None or config_to_use is None: + return + + target_idle = min(self.config.min_idle, self.config.max_idle, self.config.max_total) + if target_idle <= 0: + return + + while True: + async with self._lock: + idle_count = len(self._idle_sandboxes) + total_count = len(self._pool) + if idle_count >= target_idle or total_count >= self.config.max_total: + return + try: + await self._create_sandbox(provider_to_use, config_to_use) + except Exception as e: + self._stats["errors"] += 1 + logger.error(f"Failed to pre-create idle sandbox: {e}") + return async def _remove_from_pool(self, sandbox_id: str): """Remove sandbox from pool and indexes.""" @@ -450,6 +494,9 @@ async def _cleanup_expired(self): logger.info(f"Cleaning up expired sandbox {sandbox_id}") await self.destroy(sandbox_id) + if self.config.strategy == PoolStrategy.EAGER: + await self._ensure_min_idle() + async def _call_hook(self, hook: Callable, *args, **kwargs): """Call a lifecycle hook safely.""" try: diff --git a/sandboxes/providers/cloudflare.py b/sandboxes/providers/cloudflare.py index c4cf626..22797e8 100644 --- a/sandboxes/providers/cloudflare.py +++ b/sandboxes/providers/cloudflare.py @@ -5,6 +5,9 @@ import asyncio import base64 import json +import re +import shlex +import time import uuid from collections.abc import AsyncIterator from contextlib import suppress @@ -17,6 +20,7 @@ from ..security import validate_download_path, validate_upload_path _DEFAULT_TIMEOUT = 30.0 +_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") class CloudflareProvider(SandboxProvider): @@ -47,6 +51,7 @@ def __init__( self.account_id = account_id self._transport = transport self._user_agent = "cased-sandboxes/0.4.2" + self._last_accessed: dict[str, float] = {} @property def name(self) -> str: @@ -73,6 +78,7 @@ async def create_sandbox(self, config: SandboxConfig) -> Sandbox: "created_via": "cloudflare", }, ) + self._touch_session(session_id) return sandbox async def get_sandbox(self, sandbox_id: str) -> Sandbox | None: @@ -133,6 +139,7 @@ async def destroy_sandbox(self, sandbox_id: str) -> bool: ) except SandboxNotFoundError: return False + self._last_accessed.pop(sandbox_id, None) return True async def stream_execution( @@ -264,10 +271,11 @@ async def upload_file( # Create directory if needed dir_path = "/".join(remote_path.split("/")[:-1]) if dir_path: - await self.execute_command(sandbox_id, f"mkdir -p {dir_path}") + await self.execute_command(sandbox_id, f"mkdir -p {shlex.quote(dir_path)}") # Write file using base64 decode result = await self.execute_command( - sandbox_id, f"echo '{encoded}' | base64 -d > {remote_path}" + sandbox_id, + f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(remote_path)}", ) return result.success @@ -295,7 +303,7 @@ async def download_file( return True except (SandboxError, SandboxNotFoundError): # Fallback: use cat and base64 encoding to read file - result = await self.execute_command(sandbox_id, f"cat {remote_path} | base64") + result = await self.execute_command(sandbox_id, f"cat {shlex.quote(remote_path)} | base64") if not result.success: return False @@ -315,11 +323,14 @@ async def cleanup_idle_sandboxes(self, idle_timeout: int = 600) -> None: cleans up our tracking. Actual sandbox cleanup happens automatically. """ sandboxes = await self.list_sandboxes() - asyncio.get_event_loop().time() + now = time.time() for sandbox in sandboxes: - # Since we don't track last access time in the Worker, - # we'll clean up all sandboxes as a precaution + last_accessed = self._last_accessed.get(sandbox.id) + if last_accessed is None: + continue + if now - last_accessed <= idle_timeout: + continue with suppress(SandboxNotFoundError): await self.destroy_sandbox(sandbox.id) @@ -360,13 +371,28 @@ def _apply_env_vars_to_command( ) -> str: if not env_vars: return command - exports = " && ".join([f"export {key}='{value}'" for key, value in env_vars.items()]) + exports = " && ".join( + [ + f"export {CloudflareProvider._validate_env_var_name(key)}={shlex.quote(str(value))}" + for key, value in env_vars.items() + ] + ) return f"{exports} && {command}" + @staticmethod + def _validate_env_var_name(key: str) -> str: + if not _ENV_VAR_NAME_RE.match(key): + raise SandboxError(f"Invalid environment variable name: {key}") + return key + async def _ensure_session_exists(self, sandbox_id: str) -> None: sandbox = await self.get_sandbox(sandbox_id) if not sandbox: raise SandboxNotFoundError(f"Session {sandbox_id} not found") + self._touch_session(sandbox_id) + + def _touch_session(self, sandbox_id: str) -> None: + self._last_accessed[sandbox_id] = time.time() async def _request( self, diff --git a/sandboxes/providers/daytona.py b/sandboxes/providers/daytona.py index a58f6c3..665f819 100644 --- a/sandboxes/providers/daytona.py +++ b/sandboxes/providers/daytona.py @@ -1,6 +1,7 @@ """Daytona sandbox provider implementation.""" import logging +import math import os from typing import Any @@ -114,15 +115,24 @@ async def create_sandbox(self, config: SandboxConfig) -> Sandbox: # Configure resources if specified resources = None if config.memory_mb or config.cpu_cores: + memory_gib = None + if config.memory_mb: + # Daytona resources.memory is GiB; round up to avoid 0 GiB allocations. + memory_gib = max(1, math.ceil(config.memory_mb / 1024)) resources = Resources( cpu=int(config.cpu_cores) if config.cpu_cores else None, - memory=int(config.memory_mb / 1024) if config.memory_mb else None, + memory=memory_gib, ) logger.info( f"Configuring resources: CPU={config.cpu_cores}, Memory={config.memory_mb}MB" ) - params = CreateSandboxFromImageParams(image=image, resources=resources) + params = CreateSandboxFromImageParams( + image=image, + resources=resources, + labels=config.labels or {}, + env_vars=config.env_vars or {}, + ) else: # Use language-based creation (fallback) language = ( diff --git a/sandboxes/providers/e2b.py b/sandboxes/providers/e2b.py index 96f26ff..246904e 100644 --- a/sandboxes/providers/e2b.py +++ b/sandboxes/providers/e2b.py @@ -62,12 +62,26 @@ def name(self) -> str: async def _create_e2b_sandbox(self, template_id=None, env_vars=None, timeout=None): """Create E2B sandbox asynchronously.""" # timeout sets the sandbox lifetime in seconds - return await E2BSandbox.create( - template=template_id, - envs=env_vars, - api_key=self.api_key, - timeout=timeout or self.timeout, - ) + self._reset_e2b_transport_singleton() + try: + return await E2BSandbox.create( + template=template_id, + envs=env_vars, + api_key=self.api_key, + timeout=timeout or self.timeout, + ) + except RuntimeError as e: + # e2b>=2.13 uses a process-wide singleton async transport which can be bound + # to a closed event loop across test loops; reset and retry once. + if "Event loop is closed" not in str(e): + raise + self._reset_e2b_transport_singleton() + return await E2BSandbox.create( + template=template_id, + envs=env_vars, + api_key=self.api_key, + timeout=timeout or self.timeout, + ) def _to_sandbox(self, e2b_sandbox, metadata: dict[str, Any]) -> Sandbox: """Convert E2B sandbox to standard Sandbox.""" @@ -139,7 +153,12 @@ async def list_sandboxes(self, labels: dict[str, str] | None = None) -> list[San # First, get all sandboxes from E2B API try: # E2B's list() can return either a coroutine or AsyncSandboxPaginator depending on version - result = E2BSandbox.list() + self._reset_e2b_transport_singleton() + try: + result = E2BSandbox.list(api_key=self.api_key) + except TypeError: + # Older SDK versions don't accept api_key in list() + result = E2BSandbox.list() # Handle different return types if hasattr(result, "next_items"): @@ -183,6 +202,7 @@ async def list_sandboxes(self, labels: dict[str, str] | None = None) -> list[San "template": listed_sandbox.template_id, "name": listed_sandbox.name, "end_at": listed_sandbox.end_at, + "last_accessed": metadata.get("last_accessed"), }, ) ) @@ -202,9 +222,14 @@ async def find_sandbox(self, labels: dict[str, str]) -> Sandbox | None: """Find a running sandbox with matching labels for reuse.""" sandboxes = await self.list_sandboxes(labels=labels) if sandboxes: - # Return most recently accessed - sandboxes.sort(key=lambda s: self._sandboxes[s.id]["last_accessed"], reverse=True) - logger.info(f"Found existing sandbox {sandboxes[0].id} with labels {labels}") + # Prefer tracked sandboxes so we can reuse recency metadata safely. + tracked = [s for s in sandboxes if s.id in self._sandboxes] + if tracked: + tracked.sort(key=lambda s: self._sandboxes[s.id]["last_accessed"], reverse=True) + logger.info(f"Found existing sandbox {tracked[0].id} with labels {labels}") + return tracked[0] + + logger.info(f"Found API-listed sandbox {sandboxes[0].id} with labels {labels}") return sandboxes[0] return None @@ -457,6 +482,7 @@ async def destroy_sandbox(self, sandbox_id: str) -> bool: e2b_sandbox = metadata["e2b_sandbox"] else: # Try to connect to it via API + self._reset_e2b_transport_singleton() e2b_sandbox = await E2BSandbox.connect(sandbox_id) # Kill sandbox asynchronously @@ -538,3 +564,16 @@ def __del__(self): # Shutdown thread pool if hasattr(self, "_executor"): self._executor.shutdown(wait=False) + + @staticmethod + def _reset_e2b_transport_singleton() -> None: + """Reset e2b async transport singleton to avoid closed-loop reuse across test loops.""" + try: + from e2b.api import client_async as e2b_client_async + + transport_cls = getattr(e2b_client_async, "AsyncTransportWithLogger", None) + if transport_cls is not None: + transport_cls.singleton = None + except Exception: + # Best-effort reset for compatibility across SDK versions. + pass diff --git a/tests/test_cli.py b/tests/test_cli.py index 93234b5..4cf386c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,7 +6,7 @@ import pytest from click.testing import CliRunner -from sandboxes import ExecutionResult, SandboxConfig +from sandboxes import ExecutionResult, SandboxConfig, __version__ from sandboxes.base import Sandbox, SandboxState from sandboxes.cli import cli, get_provider @@ -65,7 +65,7 @@ def test_cli_version(self): """Test CLI version command.""" result = self.runner.invoke(cli, ["--version"]) assert result.exit_code == 0 - assert "0.2.3" in result.output + assert __version__ in result.output @patch("sandboxes.cli.asyncio.run") def test_run_command_basic(self, mock_async_run): diff --git a/tests/test_cloudflare_provider.py b/tests/test_cloudflare_provider.py index 3615b08..a88cd02 100644 --- a/tests/test_cloudflare_provider.py +++ b/tests/test_cloudflare_provider.py @@ -2,7 +2,9 @@ import json import os +import shlex import tempfile +import time from pathlib import Path import httpx @@ -429,3 +431,91 @@ async def test_upload_symlink_escape_rejected(self, mock_provider, tmp_path): ) # Without explicit allowed_dirs restriction, symlinks are followed assert result is True + + +class TestCloudflareCommandSanitization: + """Security tests for shell command construction.""" + + def test_apply_env_vars_rejects_invalid_key(self): + with pytest.raises(SandboxError, match="Invalid environment variable name"): + CloudflareProvider._apply_env_vars_to_command("echo ok", {"BAD-KEY": "value"}) + + @pytest.mark.asyncio + async def test_fallback_file_commands_quote_remote_path(self, tmp_path): + remote_path = "/workspace/evil;touch-pwn.txt" + quoted_remote_path = shlex.quote(remote_path) + observed_commands: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/session/list": + return httpx.Response(200, json={"sessions": ["quote-test"], "count": 1}) + if request.url.path in {"/api/file/write", "/api/file/read"}: + return httpx.Response(404) + if request.url.path == "/api/execute": + payload = json.loads(request.content.decode()) + command = payload["command"] + observed_commands.append(command) + if command.startswith("cat "): + import base64 + + return httpx.Response( + 200, + json={ + "stdout": base64.b64encode(b"ok").decode("utf-8"), + "stderr": "", + "exitCode": 0, + "success": True, + }, + ) + return httpx.Response( + 200, json={"stdout": "", "stderr": "", "exitCode": 0, "success": True} + ) + return httpx.Response(404) + + provider = CloudflareProvider( + base_url="https://sandbox.example.workers.dev", + transport=httpx.MockTransport(handler), + ) + + upload_path = tmp_path / "upload.txt" + upload_path.write_text("content") + download_path = tmp_path / "download.txt" + + upload_success = await provider.upload_file("quote-test", str(upload_path), remote_path) + download_success = await provider.download_file("quote-test", remote_path, str(download_path)) + + assert upload_success is True + assert download_success is True + assert download_path.read_bytes() == b"ok" + + assert any(f"> {quoted_remote_path}" in command for command in observed_commands) + assert any(f"cat {quoted_remote_path} | base64" == command for command in observed_commands) + assert all(f"> {remote_path}" not in command for command in observed_commands) + assert all(f"cat {remote_path} | base64" != command for command in observed_commands) + + +@pytest.mark.asyncio +async def test_cloudflare_cleanup_idle_respects_idle_timeout(): + deleted_sessions: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/session/list": + return httpx.Response(200, json={"sessions": ["old", "fresh"], "count": 2}) + if request.url.path == "/api/process/kill-all": + deleted_sessions.append(request.url.params.get("session", "")) + return httpx.Response(200, json={"success": True}) + return httpx.Response(404) + + provider = CloudflareProvider( + base_url="https://sandbox.example.workers.dev", + transport=httpx.MockTransport(handler), + ) + now = time.time() + provider._last_accessed = { # noqa: SLF001 - intentional test probe + "old": now - 1200, + "fresh": now - 10, + } + + await provider.cleanup_idle_sandboxes(idle_timeout=600) + + assert deleted_sessions == ["old"] diff --git a/tests/test_daytona_provider_regressions.py b/tests/test_daytona_provider_regressions.py new file mode 100644 index 0000000..c36a2b3 --- /dev/null +++ b/tests/test_daytona_provider_regressions.py @@ -0,0 +1,49 @@ +"""Regression tests for Daytona provider behaviors.""" + +import pytest + +import sandboxes.providers.daytona as daytona_module +from sandboxes.base import SandboxConfig +from sandboxes.providers.daytona import DaytonaProvider + + +@pytest.mark.asyncio +async def test_image_create_includes_labels_env_and_rounded_memory(monkeypatch): + captured: dict[str, object] = {} + + class _FakeClient: + def create(self, params, timeout=None): + captured["params"] = params + captured["timeout"] = timeout + return type( + "CreatedSandbox", + (), + { + "id": "sb-daytona-1", + "state": "running", + "labels": getattr(params, "labels", {}), + "created_at": None, + "snapshot": None, + }, + )() + + monkeypatch.setattr(daytona_module, "Daytona", lambda: _FakeClient()) + + provider = DaytonaProvider(api_key="test-key") + sandbox = await provider.create_sandbox( + SandboxConfig( + image="python:3.12", + cpu_cores=1, + memory_mb=512, + labels={"team": "platform"}, + env_vars={"HELLO": "world"}, + timeout_seconds=123, + ) + ) + + params = captured["params"] + assert sandbox.id == "sb-daytona-1" + assert params.labels == {"team": "platform"} + assert params.env_vars == {"HELLO": "world"} + assert params.resources.memory == 1 + assert captured["timeout"] == 123 diff --git a/tests/test_e2b_provider_regressions.py b/tests/test_e2b_provider_regressions.py new file mode 100644 index 0000000..9615985 --- /dev/null +++ b/tests/test_e2b_provider_regressions.py @@ -0,0 +1,83 @@ +"""Regression tests for E2B provider behaviors.""" + +from datetime import datetime + +import pytest + +import sandboxes.providers.e2b as e2b_module +from sandboxes.providers.e2b import E2BProvider + + +class _ListedSandbox: + def __init__(self, sandbox_id: str, metadata: dict[str, str]): + self.sandbox_id = sandbox_id + self.metadata = metadata + self.started_at = datetime.now() + self.state = "running" + self.template_id = "base" + self.name = sandbox_id + self.end_at = None + + +class _Paginator: + def __init__(self, items): + self._items = items + + async def next_items(self): + return self._items + + +@pytest.mark.asyncio +async def test_find_sandbox_handles_api_listed_untracked_sandbox(monkeypatch): + class _FakeE2B: + @staticmethod + def list(api_key=None): # noqa: ARG004 + return _Paginator([_ListedSandbox("sb-untracked", {"env": "prod"})]) + + monkeypatch.setattr(e2b_module, "E2BSandbox", _FakeE2B) + + provider = E2BProvider(api_key="test-key") + sandbox = await provider.find_sandbox({"env": "prod"}) + + assert sandbox is not None + assert sandbox.id == "sb-untracked" + + +@pytest.mark.asyncio +async def test_list_sandboxes_supports_legacy_list_signature(monkeypatch): + class _FakeE2BLegacy: + @staticmethod + def list(): + return _Paginator([_ListedSandbox("sb-legacy", {"team": "infra"})]) + + monkeypatch.setattr(e2b_module, "E2BSandbox", _FakeE2BLegacy) + + provider = E2BProvider(api_key="test-key") + sandboxes = await provider.list_sandboxes(labels={"team": "infra"}) + + assert len(sandboxes) == 1 + assert sandboxes[0].id == "sb-legacy" + + +@pytest.mark.asyncio +async def test_create_retries_when_e2b_transport_bound_to_closed_loop(monkeypatch): + calls = {"count": 0} + + class _FakeSandbox: + sandbox_id = "sb-retry" + + class _FakeE2B: + @staticmethod + async def create(template=None, envs=None, api_key=None, timeout=None): # noqa: ARG004 + calls["count"] += 1 + if calls["count"] == 1: + raise RuntimeError("Event loop is closed") + return _FakeSandbox() + + monkeypatch.setattr(e2b_module, "E2BSandbox", _FakeE2B) + + provider = E2BProvider(api_key="test-key") + sandbox = await provider._create_e2b_sandbox(template_id="base", env_vars={}) + + assert sandbox.sandbox_id == "sb-retry" + assert calls["count"] == 2 diff --git a/tests/test_pool.py b/tests/test_pool.py index b689736..dfc41e1 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -306,6 +306,51 @@ async def test_health_check(self, pool): assert len(unhealthy) == 1 assert unhealthy[0] == sandbox_id + @pytest.mark.asyncio + async def test_eager_strategy_prewarms_idle_sandboxes(self): + """Eager strategy should keep the configured minimum number of idle sandboxes.""" + pool = SandboxPool( + PoolConfig( + min_idle=2, + max_total=5, + max_idle=5, + strategy=PoolStrategy.EAGER, + auto_cleanup=False, + ) + ) + provider = MockProvider() + config = SandboxConfig(labels={"test": "eager"}) + + sandbox = await pool.acquire(provider, config) + assert sandbox is not None + + stats = pool.get_stats() + assert stats["total"] >= 2 + assert stats["idle"] >= 1 + assert stats["busy"] >= 1 + + @pytest.mark.asyncio + async def test_start_with_template_prewarms_idle(self): + """start(provider, config) should pre-create idle sandboxes for eager pools.""" + pool = SandboxPool( + PoolConfig( + min_idle=2, + max_total=5, + max_idle=5, + strategy=PoolStrategy.EAGER, + auto_cleanup=False, + ) + ) + provider = MockProvider() + config = SandboxConfig(labels={"test": "start-eager"}) + + await pool.start(provider, config) + stats = pool.get_stats() + + assert stats["idle"] == 2 + assert stats["busy"] == 0 + assert stats["total"] == 2 + class TestConnectionPool: """Test the ConnectionPool class specifically.""" From 26cf1a23d2ba4111a60ea72001ed6e3fd9dde4bb Mon Sep 17 00:00:00 2001 From: tnm Date: Tue, 17 Feb 2026 12:13:59 -0800 Subject: [PATCH 2/5] Use async-native Modal SDK APIs in provider --- sandboxes/providers/modal.py | 101 +++++++---------------------------- 1 file changed, 19 insertions(+), 82 deletions(-) diff --git a/sandboxes/providers/modal.py b/sandboxes/providers/modal.py index a7943c8..763cb56 100644 --- a/sandboxes/providers/modal.py +++ b/sandboxes/providers/modal.py @@ -4,7 +4,6 @@ import logging import time from collections.abc import AsyncIterator -from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any @@ -48,11 +47,6 @@ def __init__(self, **config): self.default_cpu = config.get("cpu", 2.0) self.default_memory = config.get("memory", 2048) self.timeout = config.get("timeout", 300) - self.max_workers = config.get("max_workers", 5) - - # Thread pool for blocking SDK calls - self._executor = ThreadPoolExecutor(max_workers=self.max_workers) - # Track active sandboxes with metadata self._sandboxes: dict[str, dict[str, Any]] = {} @@ -64,28 +58,6 @@ def name(self) -> str: """Provider name.""" return "modal" - def _create_modal_sandbox(self, image: str | Any, cpu: float, memory: int, timeout: int): - """Create Modal sandbox synchronously. - - Args: - image: Either a string (Docker registry image) or a modal.Image object - cpu: CPU cores to allocate - memory: Memory in MB - timeout: Timeout in seconds - """ - # Modal sandboxes require an App context - # Use a persistent app that creates itself if missing - app = modal.App.lookup("sandboxes-provider", create_if_missing=True) - - # Handle both string images and modal.Image objects - modal_image = modal.Image.from_registry(image) if isinstance(image, str) else image - - # Create Modal sandbox with specified resources - sandbox = ModalSandbox.create( - app=app, image=modal_image, cpu=cpu, memory=memory, timeout=timeout - ) - return sandbox - def _to_sandbox(self, modal_sandbox: ModalSandbox, metadata: dict[str, Any]) -> Sandbox: """Convert Modal sandbox to standard Sandbox.""" return Sandbox( @@ -125,10 +97,14 @@ async def create_sandbox(self, config: SandboxConfig) -> Sandbox: ) timeout = config.timeout_seconds or self.timeout - # Create sandbox in thread pool - loop = asyncio.get_event_loop() - modal_sandbox = await loop.run_in_executor( - self._executor, self._create_modal_sandbox, image, cpu, memory, timeout + app = await modal.App.lookup.aio("sandboxes-provider", create_if_missing=True) + modal_image = modal.Image.from_registry(image) if isinstance(image, str) else image + modal_sandbox = await ModalSandbox.create.aio( + app=app, + image=modal_image, + cpu=cpu, + memory=memory, + timeout=timeout, ) # Store metadata - include env_vars for use in each command @@ -169,10 +145,7 @@ async def get_sandbox(self, sandbox_id: str) -> Sandbox | None: # Try to fetch from Modal API try: - loop = asyncio.get_event_loop() - modal_sandbox = await loop.run_in_executor( - self._executor, lambda: ModalSandbox.from_id(sandbox_id) - ) + modal_sandbox = await ModalSandbox.from_id.aio(sandbox_id) # Create metadata for found sandbox metadata = { @@ -208,8 +181,7 @@ async def list_sandboxes(self, labels: dict[str, str] | None = None) -> list[San # Also try to list from Modal API try: - # Modal's list() is a sync generator - modal_sandboxes = list(ModalSandbox.list()) + modal_sandboxes = [s async for s in ModalSandbox.list.aio()] for modal_sandbox in modal_sandboxes: if modal_sandbox.object_id not in self._sandboxes: @@ -289,38 +261,16 @@ def validate_env_key(key: str) -> str: ) command = f"{env_setup} && {command}" - # Execute command in thread pool - loop = asyncio.get_event_loop() start_time = time.time() # Modal's exec returns a process object - # Use 'sh' instead of 'bash' for alpine compatibility - process = await loop.run_in_executor( - self._executor, - lambda: modal_sandbox.exec("sh", "-c", command, timeout=timeout or self.timeout), - ) - - # Wait for completion first - Modal SDK may require this before reading - # wait() might be sync or async - wait_result = process.wait() - exit_code = await wait_result if asyncio.iscoroutine(wait_result) else wait_result - - # Get output - Modal's read() may be sync or async depending on version - if process.stdout: - stdout_result = process.stdout.read() - stdout = ( - await stdout_result if asyncio.iscoroutine(stdout_result) else stdout_result - ) - else: - stdout = "" + # Use 'sh' instead of 'bash' for alpine compatibility. + process = await modal_sandbox.exec.aio("sh", "-c", command, timeout=timeout or self.timeout) - if process.stderr: - stderr_result = process.stderr.read() - stderr = ( - await stderr_result if asyncio.iscoroutine(stderr_result) else stderr_result - ) - else: - stderr = "" + # Wait for completion before reading process output. + exit_code = await process.wait.aio() + stdout = await process.stdout.read.aio() if process.stdout else "" + stderr = await process.stderr.read.aio() if process.stderr else "" duration_ms = int((time.time() - start_time) * 1000) @@ -364,13 +314,8 @@ async def destroy_sandbox(self, sandbox_id: str) -> bool: if sandbox_id not in self._sandboxes: # Try to fetch from API try: - loop = asyncio.get_event_loop() - modal_sandbox = await loop.run_in_executor( - self._executor, lambda: ModalSandbox.from_id(sandbox_id) - ) - - # Terminate it - await loop.run_in_executor(self._executor, lambda: modal_sandbox.terminate()) + modal_sandbox = await ModalSandbox.from_id.aio(sandbox_id) + await modal_sandbox.terminate.aio() return True except Exception: return False @@ -379,9 +324,7 @@ async def destroy_sandbox(self, sandbox_id: str) -> bool: metadata = self._sandboxes[sandbox_id] modal_sandbox = metadata["modal_sandbox"] - # Terminate sandbox in thread pool - loop = asyncio.get_event_loop() - await loop.run_in_executor(self._executor, lambda: modal_sandbox.terminate()) + await modal_sandbox.terminate.aio() # Remove from tracking async with self._lock: @@ -452,9 +395,3 @@ async def cleanup_idle_sandboxes(self, idle_timeout: int = 600): for sandbox_id in to_destroy: logger.info(f"Cleaning up idle sandbox {sandbox_id}") await self.destroy_sandbox(sandbox_id) - - def __del__(self): - """Cleanup on deletion.""" - # Shutdown thread pool - if hasattr(self, "_executor"): - self._executor.shutdown(wait=False) From 19007509132ca08962a716773d0d1ca601ee2f32 Mon Sep 17 00:00:00 2001 From: tnm Date: Tue, 17 Feb 2026 12:20:23 -0800 Subject: [PATCH 3/5] Apply black formatting for CI lint --- sandboxes/providers/cloudflare.py | 4 +++- sandboxes/providers/modal.py | 4 +++- tests/test_cloudflare_provider.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sandboxes/providers/cloudflare.py b/sandboxes/providers/cloudflare.py index 22797e8..5f2bdb4 100644 --- a/sandboxes/providers/cloudflare.py +++ b/sandboxes/providers/cloudflare.py @@ -303,7 +303,9 @@ async def download_file( return True except (SandboxError, SandboxNotFoundError): # Fallback: use cat and base64 encoding to read file - result = await self.execute_command(sandbox_id, f"cat {shlex.quote(remote_path)} | base64") + result = await self.execute_command( + sandbox_id, f"cat {shlex.quote(remote_path)} | base64" + ) if not result.success: return False diff --git a/sandboxes/providers/modal.py b/sandboxes/providers/modal.py index 763cb56..7a77fc5 100644 --- a/sandboxes/providers/modal.py +++ b/sandboxes/providers/modal.py @@ -265,7 +265,9 @@ def validate_env_key(key: str) -> str: # Modal's exec returns a process object # Use 'sh' instead of 'bash' for alpine compatibility. - process = await modal_sandbox.exec.aio("sh", "-c", command, timeout=timeout or self.timeout) + process = await modal_sandbox.exec.aio( + "sh", "-c", command, timeout=timeout or self.timeout + ) # Wait for completion before reading process output. exit_code = await process.wait.aio() diff --git a/tests/test_cloudflare_provider.py b/tests/test_cloudflare_provider.py index a88cd02..4365646 100644 --- a/tests/test_cloudflare_provider.py +++ b/tests/test_cloudflare_provider.py @@ -482,7 +482,9 @@ def handler(request: httpx.Request) -> httpx.Response: download_path = tmp_path / "download.txt" upload_success = await provider.upload_file("quote-test", str(upload_path), remote_path) - download_success = await provider.download_file("quote-test", remote_path, str(download_path)) + download_success = await provider.download_file( + "quote-test", remote_path, str(download_path) + ) assert upload_success is True assert download_success is True From ecc990c44464dbd5fa44f03239b103aabc27d432 Mon Sep 17 00:00:00 2001 From: tnm Date: Tue, 17 Feb 2026 12:23:33 -0800 Subject: [PATCH 4/5] fix: use StrEnum for circuit breaker state --- sandboxes/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sandboxes/retry.py b/sandboxes/retry.py index 82f3a54..9c4d766 100644 --- a/sandboxes/retry.py +++ b/sandboxes/retry.py @@ -8,7 +8,7 @@ import random from collections.abc import Callable from dataclasses import dataclass -from enum import Enum +from enum import StrEnum from functools import wraps from typing import Any, TypeVar @@ -24,7 +24,7 @@ T = TypeVar("T") -class CircuitBreakerState(str, Enum): +class CircuitBreakerState(StrEnum): """Circuit breaker states.""" CLOSED = "closed" From faa7c272ee781e204b1c4e34a9a4449c3f448071 Mon Sep 17 00:00:00 2001 From: tnm Date: Tue, 17 Feb 2026 12:26:58 -0800 Subject: [PATCH 5/5] style: black-format benchmark scripts --- benchmarks/comprehensive_benchmark.py | 23 ++++++++++++++++------- benchmarks/run_all_benchmarks.py | 1 + 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/benchmarks/comprehensive_benchmark.py b/benchmarks/comprehensive_benchmark.py index 9cfc59a..f25d592 100644 --- a/benchmarks/comprehensive_benchmark.py +++ b/benchmarks/comprehensive_benchmark.py @@ -13,6 +13,7 @@ Based on ai-sandbox-benchmark (Apache 2.0 License) https://github.com/nibzard/ai-sandbox-benchmark """ + import asyncio import os import sys @@ -51,7 +52,8 @@ }, "prime_calculation": { "name": "Prime Calculation", - "command": """python3 -c " + "command": ( + """python3 -c " def is_prime(n): if n < 2: return False for i in range(2, int(n**0.5) + 1): @@ -61,13 +63,15 @@ def is_prime(n): primes = [n for n in range(2, 1000) if is_prime(n)] print(f'Found {len(primes)} primes') " -""", +""" + ), "runs": 5, "description": "CPU-bound computation", }, "file_io": { "name": "File I/O (1000 files)", - "command": """python3 -c " + "command": ( + """python3 -c " import os # Write 1000 small files for i in range(1000): @@ -82,25 +86,30 @@ def is_prime(n): print(f'Processed {total} bytes') " -""", +""" + ), "runs": 3, "description": "I/O performance test", }, "package_install": { "name": "pip install requests", - "command": "pip install -q requests && python3 -c 'import requests; print(f\"requests {requests.__version__}\")'", + "command": ( + "pip install -q requests && python3 -c 'import requests; print(f\"requests {requests.__version__}\")'" + ), "runs": 2, "description": "Package installation speed (requests already installed in standard image)", }, "numpy_fft": { "name": "NumPy FFT", - "command": """python3 -c " + "command": ( + """python3 -c " import numpy as np x = np.random.random(10000) result = np.fft.fft(x) print(f'FFT: {len(result)} points') " -""", +""" + ), "runs": 3, "description": "Numerical computation with pre-installed packages", }, diff --git a/benchmarks/run_all_benchmarks.py b/benchmarks/run_all_benchmarks.py index 20504e8..59af011 100644 --- a/benchmarks/run_all_benchmarks.py +++ b/benchmarks/run_all_benchmarks.py @@ -7,6 +7,7 @@ Outputs comprehensive results to benchmarks/results.txt """ + import subprocess import sys import time