diff --git a/docs/sdk/agent_tools.mdx b/docs/sdk/agent_tools.mdx index 4eb585e..859670a 100644 --- a/docs/sdk/agent_tools.mdx +++ b/docs/sdk/agent_tools.mdx @@ -21,6 +21,7 @@ Inheriting from this class provides: - Pydantic's declarative syntax for defining state (fields). - Automatic application of the `@configurable` decorator. - A `get_tools` method for discovering methods decorated with `@dreadnode.tool_method`. +- Support for async context management, with automatic re-entrancy handling. ### name diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index 7cf7f8f..a87cafb 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -172,6 +172,7 @@ def create_project( ```python create_workspace( name: str, + key: str, organization_id: UUID, description: str | None = None, ) -> Workspace @@ -198,6 +199,7 @@ Creates a new workspace. def create_workspace( self, name: str, + key: str, organization_id: UUID, description: str | None = None, ) -> Workspace: @@ -214,6 +216,7 @@ def create_workspace( payload = { "name": name, + "key": key, "description": description, "org_id": str(organization_id), } @@ -237,7 +240,7 @@ Deletes a specific workspace. * **`workspace_id`** (`str | UUID`) - –The workspace identifier. + –The workspace key. ```python @@ -246,7 +249,7 @@ def delete_workspace(self, workspace_id: str | UUID) -> None: Deletes a specific workspace. Args: - workspace_id (str | UUID): The workspace identifier. + workspace_id (str | UUID): The workspace key. """ self.request("DELETE", f"/workspaces/{workspace_id!s}") @@ -707,17 +710,15 @@ def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse: ### get\_organization ```python -get_organization( - organization_id: str | UUID, -) -> Organization +get_organization(org_id_or_key: UUID | str) -> Organization ``` Retrieves details of a specific organization. **Parameters:** -* **`organization_id`** - (`str`) +* **`org_id_or_key`** + (`str | UUID`) –The organization identifier. **Returns:** @@ -727,17 +728,17 @@ Retrieves details of a specific organization. ```python -def get_organization(self, organization_id: str | UUID) -> Organization: +def get_organization(self, org_id_or_key: UUID | str) -> Organization: """ Retrieves details of a specific organization. Args: - organization_id (str): The organization identifier. + org_id_or_key (str | UUID): The organization identifier. Returns: Organization: The Organization object. """ - response = self.request("GET", f"/organizations/{organization_id!s}") + response = self.request("GET", f"/organizations/{org_id_or_key!s}") return Organization(**response.json()) ``` @@ -1098,7 +1099,8 @@ def get_user_data_credentials(self) -> UserDataCredentials: ```python get_workspace( - workspace_id: str | UUID, org_id: UUID | None = None + workspace_id_or_key: UUID | str, + org_id: UUID | None = None, ) -> Workspace ``` @@ -1106,8 +1108,8 @@ Retrieves details of a specific workspace. **Parameters:** -* **`workspace_id`** - (`str`) +* **`workspace_id_or_key`** + (`str | UUID`) –The workspace identifier. **Returns:** @@ -1117,12 +1119,14 @@ Retrieves details of a specific workspace. ```python -def get_workspace(self, workspace_id: str | UUID, org_id: UUID | None = None) -> Workspace: +def get_workspace( + self, workspace_id_or_key: UUID | str, org_id: UUID | None = None +) -> Workspace: """ Retrieves details of a specific workspace. Args: - workspace_id (str): The workspace identifier. + workspace_id_or_key (str | UUID): The workspace identifier. Returns: Workspace: The Workspace object. @@ -1130,7 +1134,7 @@ def get_workspace(self, workspace_id: str | UUID, org_id: UUID | None = None) -> params: dict[str, str] = {} if org_id: params = {"org_id": str(org_id)} - response = self.request("GET", f"/workspaces/{workspace_id!s}", params=params) + response = self.request("GET", f"/workspaces/{workspace_id_or_key!s}", params=params) return Workspace(**response.json()) ``` @@ -1793,6 +1797,14 @@ is_active: bool Is the organization active? +### key + +```python +key: str +``` + +URL-friendly identifier for the organization. + ### max\_members ```python @@ -1809,14 +1821,6 @@ name: str Name of the organization. -### slug - -```python -slug: str -``` - -URL-friendly slug for the organization. - ### updated\_at ```python @@ -2321,6 +2325,14 @@ is_default: bool Is the workspace the default one? +### key + +```python +key: str +``` + +Unique key for the workspace. + ### name ```python @@ -2353,14 +2365,6 @@ project_count: int | None Number of projects in the workspace. -### slug - -```python -slug: str -``` - -URL-friendly slug for the workspace. - ### updated\_at ```python diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index c8449f4..cd47fe0 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -173,10 +173,13 @@ in the following order: 1. Environment variables: 2. `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` 3. `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` -4. Dreadnode profile (from `dreadnode login`) -5. Uses `profile` parameter if provided -6. Falls back to `DREADNODE_PROFILE` environment variable -7. Defaults to active profile +4. `DREADNODE_ORGANIZATION` +5. `DREADNODE_WORKSPACE` +6. `DREADNODE_PROJECT` +7. Dreadnode profile (from `dreadnode login`) +8. Uses `profile` parameter if provided +9. Falls back to `DREADNODE_PROFILE` environment variable +10. Defaults to active profile **Parameters:** @@ -214,7 +217,7 @@ in the following order: (`str | None`, default: `None` ) - –The default project name to associate all runs with. This can also be in the format `org/workspace/project`. + –The default project name to associate all runs with. This can also be in the format `org/workspace/project` using the keys. * **`service_name`** (`str | None`, default: `None` @@ -270,6 +273,10 @@ def configure( 1. Environment variables: - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + - `DREADNODE_ORGANIZATION` + - `DREADNODE_WORKSPACE` + - `DREADNODE_PROJECT` + 2. Dreadnode profile (from `dreadnode login`) - Uses `profile` parameter if provided - Falls back to `DREADNODE_PROFILE` environment variable @@ -282,7 +289,7 @@ def configure( local_dir: The local directory to store data in. organization: The default organization name or ID to use. workspace: The default workspace name or ID to use. - project: The default project name to associate all runs with. This can also be in the format `org/workspace/project`. + project: The default project name to associate all runs with. This can also be in the format `org/workspace/project` using the keys. service_name: The service name to use for OpenTelemetry. service_version: The service version to use for OpenTelemetry. console: Log span information to the console (`DREADNODE_CONSOLE` or the default is True). @@ -342,19 +349,8 @@ def configure( self.local_dir = local_dir _org, _workspace, _project = self._extract_project_components(project) - self.organization = _org or organization or os.environ.get(ENV_ORGANIZATION) - with contextlib.suppress(ValueError): - self.organization = UUID( - str(self.organization) - ) # Now, it's a UUID if possible, else str (name/slug) - self.workspace = _workspace or workspace or os.environ.get(ENV_WORKSPACE) - with contextlib.suppress(ValueError): - self.workspace = UUID( - str(self.workspace) - ) # Now, it's a UUID if possible, else str (name/slug) - self.project = _project or project or os.environ.get(ENV_PROJECT) self.service_name = service_name diff --git a/docs/sdk/task.mdx b/docs/sdk/task.mdx index 0cf9fe3..91a9da8 100644 --- a/docs/sdk/task.mdx +++ b/docs/sdk/task.mdx @@ -683,6 +683,7 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: # # Log the output + output_object_hash = None if log_output and ( not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) ): @@ -691,13 +692,12 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: # output, attributes={"auto": True}, ) - elif run is not None: # Link the output to the inputs - for input_object_hash in input_object_hashes: - run.link_objects(output_object_hash, input_object_hash) - - if create_run: - run.log_output("output", output, attributes={"auto": True}) + if run is not None: + for input_object_hash in input_object_hashes: + run.link_objects(output_object_hash, input_object_hash) + elif run is not None and create_run: + run.log_output("output", output, attributes={"auto": True}) # Score and check assertions diff --git a/dreadnode/agent/agent.py b/dreadnode/agent/agent.py index 6170bfa..a680b87 100644 --- a/dreadnode/agent/agent.py +++ b/dreadnode/agent/agent.py @@ -2,7 +2,7 @@ import json import re import typing as t -from contextlib import aclosing, asynccontextmanager +from contextlib import AsyncExitStack, aclosing, asynccontextmanager from copy import deepcopy from textwrap import dedent @@ -875,8 +875,16 @@ async def stream( commit: CommitBehavior = "always", ) -> t.AsyncIterator[t.AsyncGenerator[AgentEvent, None]]: thread = thread or self.thread - async with aclosing(self._stream_traced(thread, user_input, commit=commit)) as stream: - yield stream + + async with AsyncExitStack() as stack: + # Ensure all tools are properly entered if they + # are context managers before we start using them + for tool_container in self.tools: + if hasattr(tool_container, "__aenter__") and hasattr(tool_container, "__aexit__"): + await stack.enter_async_context(tool_container) + + async with aclosing(self._stream_traced(thread, user_input, commit=commit)) as stream: + yield stream async def run( self, diff --git a/dreadnode/agent/tools/base.py b/dreadnode/agent/tools/base.py index ac63164..1c82b87 100644 --- a/dreadnode/agent/tools/base.py +++ b/dreadnode/agent/tools/base.py @@ -1,6 +1,8 @@ +import asyncio +import functools import typing as t -from pydantic import ConfigDict +from pydantic import ConfigDict, PrivateAttr from rigging import tools from rigging.tools.base import ToolMethod as RiggingToolMethod @@ -171,18 +173,82 @@ class Toolset(Model): - Pydantic's declarative syntax for defining state (fields). - Automatic application of the `@configurable` decorator. - A `get_tools` method for discovering methods decorated with `@dreadnode.tool_method`. + - Support for async context management, with automatic re-entrancy handling. """ + model_config = ConfigDict(arbitrary_types_allowed=True, use_attribute_docstrings=True) + variant: str | None = None """The variant for filtering tools available in this toolset.""" - model_config = ConfigDict(arbitrary_types_allowed=True, use_attribute_docstrings=True) + # Context manager magic + _entry_ref_count: int = PrivateAttr(default=0) + _context_handle: object = PrivateAttr(default=None) + _entry_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) @property def name(self) -> str: """The name of the toolset, derived from the class name.""" return self.__class__.__name__ + def __init_subclass__(cls, **kwargs: t.Any) -> None: + super().__init_subclass__(**kwargs) + + # This essentially ensures that if the Toolset is any kind of context manager, + # it will be re-entrant, and only actually enter/exit once. This means we can + # safely build auto-entry/exit logic into our Agent class without worrying about + # breaking the code if the user happens to enter a toolset manually before using + # it in an agent. + + original_aenter = cls.__dict__.get("__aenter__") + original_enter = cls.__dict__.get("__enter__") + original_aexit = cls.__dict__.get("__aexit__") + original_exit = cls.__dict__.get("__exit__") + + has_enter = callable(original_aenter) or callable(original_enter) + has_exit = callable(original_aexit) or callable(original_exit) + + if has_enter and not has_exit: + raise TypeError( + f"{cls.__name__} defining __aenter__ or __enter__ must also define __aexit__ or __exit__" + ) + if has_exit and not has_enter: + raise TypeError( + f"{cls.__name__} defining __aexit__ or __exit__ must also define __aenter__ or __enter__" + ) + if original_aenter and original_enter: + raise TypeError(f"{cls.__name__} cannot define both __aenter__ and __enter__") + if original_aexit and original_exit: + raise TypeError(f"{cls.__name__} cannot define both __aexit__ and __exit__") + + @functools.wraps(original_aenter or original_enter) # type: ignore[arg-type] + async def aenter_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any: + async with self._entry_lock: + if self._entry_ref_count == 0: + handle = None + if original_aenter: + handle = await original_aenter(self, *args, **kwargs) + elif original_enter: + handle = original_enter(self, *args, **kwargs) + self._context_handle = handle if handle is not None else self + self._entry_ref_count += 1 + return self._context_handle + + cls.__aenter__ = aenter_wrapper # type: ignore[attr-defined] + + @functools.wraps(original_aexit or original_exit) # type: ignore[arg-type] + async def aexit_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any: + async with self._entry_lock: + self._entry_ref_count -= 1 + if self._entry_ref_count == 0: + if original_aexit: + await original_aexit(self, *args, **kwargs) + elif original_exit: + original_exit(self, *args, **kwargs) + self._context_handle = None + + cls.__aexit__ = aexit_wrapper # type: ignore[attr-defined] + def get_tools(self, *, variant: str | None = None) -> list[AnyTool]: variant = variant or self.variant diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index 194cb41..be011d6 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -502,7 +502,7 @@ class Organization(BaseModel): name: str """Name of the organization.""" key: str - """URL-friendly identifer for the organization.""" + """URL-friendly identifier for the organization.""" description: str | None """Description of the organization.""" is_active: bool diff --git a/tests/test_agent_lifecycle.py b/tests/test_agent_lifecycle.py new file mode 100644 index 0000000..eefa048 --- /dev/null +++ b/tests/test_agent_lifecycle.py @@ -0,0 +1,239 @@ +import asyncio +import inspect +import typing as t + +import pytest + +from dreadnode.agent import Agent +from dreadnode.agent.tools import Toolset, tool, tool_method + +if t.TYPE_CHECKING: + from dreadnode.agent.tools.base import AnyTool + +# This is the state tracker that will record the order of events. +event_log: list[str] = [] + + +class AsyncCMToolSet(Toolset): + """ + Scenario 1: A standard, async-native Toolset. + Tests that __aenter__/__aexit__ are called correctly and only once. + """ + + enter_count: int = 0 + exit_count: int = 0 + + async def __aenter__(self) -> "AsyncCMToolSet": + event_log.append("async_tool_enter_start") + await asyncio.sleep(0.01) # Simulate async work + self.enter_count += 1 + event_log.append("async_tool_enter_end") + return self + + async def __aexit__(self, *args: object) -> None: + event_log.append("async_tool_exit_start") + await asyncio.sleep(0.01) + self.exit_count += 1 + event_log.append("async_tool_exit_end") + + @tool_method + async def do_work(self) -> str: + """A sample method for the agent to call.""" + event_log.append("async_tool_method_called") + return "async work done" + + +class SyncCMToolSet(Toolset): + """ + Scenario 2: A Toolset using synchronous __enter__/__exit__. + Tests that our magic bridge correctly calls them in order. + """ + + def __enter__(self) -> "SyncCMToolSet": + event_log.append("sync_tool_enter") + return self + + def __exit__(self, *args: object) -> None: + event_log.append("sync_tool_exit") + + @tool_method + def do_blocking_work(self) -> str: + """A sample sync method.""" + event_log.append("sync_tool_method_called") + return "sync work done" + + +@tool +async def stateless_tool() -> str: + """ + Scenario 3: A simple, stateless tool. + Tests that the lifecycle manager ignores it. + """ + event_log.append("stateless_tool_called") + return "stateless work done" + + +class StandaloneCMToolset(Toolset): + """ + Scenario 4 (Revised): A Toolset that acts as a standalone context manager. + It must inherit from Toolset to be preserved by the Agent's validator. + """ + + async def __aenter__(self) -> "StandaloneCMToolset": + event_log.append("standalone_cm_enter") + return self + + async def __aexit__(self, *args: object) -> None: + event_log.append("standalone_cm_exit") + + @tool_method + async def do_standalone_work(self) -> str: + """The callable part of the tool.""" + event_log.append("standalone_cm_called") + return "standalone work done" + + +class ReturnValueToolSet(Toolset): + """ + Scenario 5: A Toolset whose __aenter__ returns a different object. + Tests that the `as` clause contract is honored. + """ + + class Handle: + def __init__(self, message: str) -> None: + self.message = message + + async def __aenter__(self) -> "ReturnValueToolSet.Handle": + event_log.append("return_value_tool_enter") + # Return a handle object, NOT self + return self.Handle("special handle") + + async def __aexit__(self, *args: object) -> None: + event_log.append("return_value_tool_exit") + + +# --- Mock Agent to Control Execution --- + + +class MockAgent(Agent): + """ + An agent override that doesn't call an LLM. Instead, it simulates + a run where it calls every available tool once. + """ + + async def _stream_traced( # type: ignore[override] + self, + thread: object, # noqa: ARG002 + user_input: str, # noqa: ARG002 + *, + commit: bool = True, # noqa: ARG002 + ) -> t.AsyncIterator[str]: + event_log.append("agent_run_start") + # Simulate calling each tool the agent knows about + for tool_ in self.all_tools: + result = tool_() + if inspect.isawaitable(result): + await result + event_log.append("agent_run_end") + # Yield a dummy event to satisfy the stream consumer + yield "dummy_event" + + +# --- The Tests --- + + +@pytest.mark.asyncio +async def test_agent_manages_all_lifecycle_scenarios() -> None: + """ + Main integration test. Verifies that the Agent correctly manages setup, + execution, and teardown for a mix of tool types in the correct order. + """ + event_log.clear() + + # 1. Setup our collection of tools + async_tool = AsyncCMToolSet() + sync_tool = SyncCMToolSet() + standalone_toolset = StandaloneCMToolset() + + # The list passed to the Agent contains the containers + agent_tools: list[AnyTool | Toolset] = [ + async_tool, + sync_tool, + stateless_tool, + standalone_toolset, + ] + + agent = MockAgent(name="test_agent", tools=agent_tools) + + # 2. Execute the agent run within its stream context + async with agent.stream("test input") as stream: + event_log.append("stream_context_active") + async for _ in stream: + pass # Consume the stream to trigger the run + + # 3. Assert the order of events + expected_order = [ + # Setup phase (order of entry is guaranteed by list order) + "async_tool_enter_start", + "async_tool_enter_end", + "sync_tool_enter", + "standalone_cm_enter", + # Agent execution phase + "stream_context_active", + "agent_run_start", + "agent_run_end", + # Teardown phase (must be LIFO) + "standalone_cm_exit", + "sync_tool_exit", + "async_tool_exit_start", + "async_tool_exit_end", + ] + + # Extract the tool call events to check for presence separately + run_events = [e for e in event_log if e.endswith("_called")] + actual_order_without_run_events = [e for e in event_log if not e.endswith("_called")] + + assert actual_order_without_run_events == expected_order + assert sorted(run_events) == sorted( + [ + "async_tool_method_called", + "sync_tool_method_called", + "stateless_tool_called", + "standalone_cm_called", + ] + ) + + # 4. Assert idempotency (enter/exit should have been called only once) + assert async_tool.enter_count == 1 + assert async_tool.exit_count == 1 + + +@pytest.mark.asyncio +async def test_toolset_idempotency_wrapper() -> None: + """ + A tight unit test to verify that our wrapper magic correctly + prevents a toolset from being entered more than once. + """ + tool = AsyncCMToolSet() + + # Nesting the context manager simulates the agent entering it + # after the user might have manually (and incorrectly) entered it. + async with tool as outer_handle: + assert tool.enter_count == 1 + async with tool as inner_handle: + assert tool.enter_count == 1 # Should NOT have increased + assert inner_handle is outer_handle # Should return the same handle + + assert tool.exit_count == 1 # Exit logic should only have run once + + +@pytest.mark.asyncio +async def test_toolset_return_value_is_honored() -> None: + """ + Verifies that the handle returned by a custom __aenter__ is preserved. + """ + tool = ReturnValueToolSet() + + async with tool as handle: + assert isinstance(handle, tool.Handle) + assert handle.message == "special handle"