From bc77db33dfba61de232cf72706cfea256bf59705 Mon Sep 17 00:00:00 2001 From: Calvin Giles Date: Fri, 23 May 2025 19:27:23 +1200 Subject: [PATCH 1/7] feat: Rich Approval Mechanism for Human in the loop Tool Approvals --- src/google/adk/approval/__init__.py | 16 + src/google/adk/approval/approval_grant.py | 59 ++ src/google/adk/approval/approval_handler.py | 499 +++++++++++ src/google/adk/approval/approval_policy.py | 255 ++++++ src/google/adk/approval/approval_request.py | 166 ++++ .../approval/approval_request_processor.py | 307 +++++++ src/google/adk/events/event_actions.py | 14 + .../adk/flows/llm_flows/base_llm_flow.py | 10 + src/google/adk/flows/llm_flows/contents.py | 36 +- src/google/adk/flows/llm_flows/functions.py | 109 ++- src/google/adk/flows/llm_flows/single_flow.py | 20 +- src/google/adk/tools/agent_tool.py | 53 ++ src/google/adk/tools/crewai_tool.py | 23 +- src/google/adk/tools/langchain_tool.py | 21 + src/google/adk/tools/tool_context.py | 38 +- tests/unittests/approval/__init__.py | 0 .../unittests/approval/test_approval_grant.py | 64 ++ .../approval/test_approval_handler.py | 803 ++++++++++++++++++ .../approval/test_approval_policy.py | 227 +++++ .../approval/test_approval_preprocessor.py | 140 +++ tests/unittests/approval/utils.py | 57 ++ tests/unittests/conftest.py | 18 + .../flows/llm_flows/test_approval.py | 799 +++++++++++++++++ .../llm_flows/test_async_tool_callbacks.py | 2 +- tests/unittests/testing_utils.py | 54 +- 25 files changed, 3773 insertions(+), 17 deletions(-) create mode 100644 src/google/adk/approval/__init__.py create mode 100644 src/google/adk/approval/approval_grant.py create mode 100644 src/google/adk/approval/approval_handler.py create mode 100644 src/google/adk/approval/approval_policy.py create mode 100644 src/google/adk/approval/approval_request.py create mode 100644 src/google/adk/approval/approval_request_processor.py create mode 100644 tests/unittests/approval/__init__.py create mode 100644 tests/unittests/approval/test_approval_grant.py create mode 100644 tests/unittests/approval/test_approval_handler.py create mode 100644 tests/unittests/approval/test_approval_policy.py create mode 100644 tests/unittests/approval/test_approval_preprocessor.py create mode 100644 tests/unittests/approval/utils.py create mode 100644 tests/unittests/flows/llm_flows/test_approval.py diff --git a/src/google/adk/approval/__init__.py b/src/google/adk/approval/__init__.py new file mode 100644 index 0000000000..6288cba286 --- /dev/null +++ b/src/google/adk/approval/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides foundational classes and handlers for managing approvals in the ADK.""" + diff --git a/src/google/adk/approval/approval_grant.py b/src/google/adk/approval/approval_grant.py new file mode 100644 index 0000000000..bc2f9892aa --- /dev/null +++ b/src/google/adk/approval/approval_grant.py @@ -0,0 +1,59 @@ + +"""Defines the core data structures for representing approval grants. 
+
+This module includes classes for:
+- `ApprovalActor`: Represents an entity (user, agent, tool) involved in an approval.
+- `ApprovalEffect`: Enumerates the possible outcomes of an approval (allow, deny, challenge).
+- `ApprovalGrant`: Encapsulates the details of a permission grant, including the
+  effect, actions, resources, grantee, grantor, and optional expiration.
+"""
+from __future__ import annotations
+from datetime import datetime
+from enum import Enum
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+
+class ApprovalActor(BaseModel):
+  id: str
+  """A unique identifier for the actor (e.g., user ID, agent session ID, tool call ID)."""
+  type: Literal["user", "agent", "tool"]
+  """The type of the actor."""
+  on_behalf_of: ApprovalActor | None = None
+  """The actor on whose behalf this actor is operating, if any (e.g., an agent acting on behalf of a user)."""
+
+
+class ApprovalEffect(str, Enum):
+  allow = "allow"
+  """Indicates that the requested action is permitted."""
+  deny = "deny"
+  """Indicates that the requested action is explicitly forbidden."""
+  challenge = "challenge"
+  """Indicates that further information or confirmation is required before allowing or denying."""
+
+
+ApprovalAction = str
+"""Type alias for an action string (e.g., 'tool:read_file', 'agent:use')."""
+ApprovalResource = str
+"""Type alias for a resource string (e.g., 'tool:files:/path/to/file', 'agent:agent_name')."""
+
+
+class ApprovalGrant(BaseModel):
+  """Applies the effect to the actions on the resources, for the grantee, as authorized by the grantor, until the expiration."""
+
+  effect: Literal[ApprovalEffect.allow, ApprovalEffect.deny]
+  """The effect of this grant, either allowing or denying the specified actions on the resources."""
+  actions: list[ApprovalAction]
+  """A list of actions (e.g., 'tool:read_file') that this grant permits or denies."""
+  resources: list[ApprovalResource]
+  """A list of resources (e.g., 'tool:files:/path/to/data.txt') to which this grant applies."""
+  grantee: ApprovalActor
+  """The actor (user, agent, or tool) to whom the permissions are granted."""
+  grantor: ApprovalActor
+  """The actor who authorized this grant (e.g., an end-user or a delegating agent)."""
+  expiration_time: Optional[datetime] = None
+  """The optional time after which this grant is no longer valid. If None, the grant does not expire."""
+
+  comment: Optional[str] = None
+  """An optional comment from the grantor, often used to explain the reason for a denial or to provide context for an approval."""
diff --git a/src/google/adk/approval/approval_handler.py b/src/google/adk/approval/approval_handler.py
new file mode 100644
index 0000000000..d3b9b5b2c1
--- /dev/null
+++ b/src/google/adk/approval/approval_handler.py
@@ -0,0 +1,499 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles the logic for managing and evaluating approval requests against policies and grants.
+ +This module provides the `ApprovalHandler` class, which is responsible for: +- Parsing and storing approval responses. +- Determining if a given function call requires approval based on registered policies and existing grants. +- Generating approval requests (challenges) when necessary. +- Checking if actions on resources are permitted by existing grants. +""" + +from __future__ import annotations + +from typing import Any, Dict, TYPE_CHECKING, Union + +from typing import Literal, Optional + +from google.genai.types import FunctionCall +from pydantic import BaseModel + +from .approval_grant import ApprovalAction, ApprovalActor, ApprovalEffect, ApprovalGrant +from .approval_policy import ApprovalPolicyRegistry, TOOL_NAMESPACE +from .approval_request import ApprovalChallenge, ApprovalDenied, ApprovalRequest, ApprovalResponse, FunctionCallStatus +from ..tools import ToolContext + +if TYPE_CHECKING: + from ..sessions.state import State + + +class ApprovalHandler(object): + """Manages the lifecycle of approval requests, from checking policies to storing grants. + + This class provides static methods to interact with the approval system. It does not + maintain its own state but operates on the state provided to its methods (typically + session state). + """ + + @classmethod + def parse_and_store_approval_responses( + cls, + initial_grants: list[ApprovalGrant], + approval_responses: list[ApprovalResponse], + ) -> list[ApprovalGrant]: + """Parses approval responses and extracts new grants. + + Compares grants from approval responses with existing initial grants to identify + and return only the newly added grants. + + Args: + initial_grants: A list of already existing approval grants. + approval_responses: A list of approval responses, each potentially containing grants. + + Returns: + A list of new `ApprovalGrant` objects that were not present in `initial_grants`. + """ + extra_grants = [] + for approval_response in approval_responses: + added_grants = [ + grant + for grant in approval_response.grants + if grant not in initial_grants and grant not in extra_grants + ] + extra_grants = extra_grants + added_grants + + return extra_grants + + @classmethod + def get_approval_request( + cls, + function_call: FunctionCall, + state: Dict[str, Any], + tool_context: ToolContext, + user_id: str, + session_id: str, + ) -> Optional[Dict[str, Any]]: + """Determines if a function call requires approval and generates a request if so. + + This method checks the given `function_call` against registered approval policies + and existing grants in the `state`. If approval is required (i.e., there are + pending challenges), it updates the `tool_context` to suspend the function call + and request approval. + + If the function call is already approved or doesn't require approval, it updates + the `tool_context` to mark the function call as resumed. + + If a deny grant explicitly forbids the call, an `ApprovalDenied` exception is caught, + and the function call is marked as cancelled. + + Args: + function_call: The `FunctionCall` object to be checked. + state: The current session state, containing existing grants and suspended calls. + tool_context: The context for the current tool call, used to update its status + and request approval. + user_id: The ID of the user initiating the call. + session_id: The ID of the current session. + + Returns: + A dictionary with a "status" key if approval is requested or denied: + - {"status": "approval_requested"} if challenges are pending. 
+ - {"status": "denied", "denied_challenges": ...} if the call is denied. + Returns `None` if the function call can proceed without further approval. + """ + try: + if approval_request := cls._get_pending_challenges( + state=state, + tool_call=function_call, + user_id=user_id, + session_id=session_id, + ): + tool_context.state["approvals__suspended_function_calls"] = ( + FunctionCallStatus.update_status( + tool_context.state.get( + "approvals__suspended_function_calls", [] + ), + function_call, + "suspended", + ) + ) + + tool_context.request_approval(approval_request) + return { + "status": "approval_requested", + } + else: + tool_context.state["approvals__suspended_function_calls"] = ( + FunctionCallStatus.update_status( + tool_context.state.get( + "approvals__suspended_function_calls", [] + ), + function_call, + "resumed", + ) + ) + return None + except ApprovalDenied as e: + tool_context.state["approvals__suspended_function_calls"] = ( + FunctionCallStatus.update_status( + tool_context.state.get("approvals__suspended_function_calls", []), + function_call, + "cancelled", + ) + ) + + return { + "status": "denied", + "denied_challenges": e.denied_challenges, + } + + @classmethod + def _get_pending_challenges( + cls, + state: Union[State, Dict[str, Any]], + tool_call: FunctionCall, + user_id: str, + session_id: str, + ) -> Optional[ApprovalRequest]: + """Checks a tool call against policies and grants to identify pending challenges. + + This method evaluates the `tool_call` against all registered `ApprovalPolicy` + objects relevant to the tool. It then checks existing `ApprovalGrant` objects + in the `state` to see if the required actions on resources are permitted. + + If any required action/resource pair is explicitly denied by a grant, this method + raises an `ApprovalDenied` exception. + + If there are action/resource pairs required by policies that are not covered by + any existing allow grants, these are collected into an `ApprovalRequest`. + + Args: + state: The session state or a dictionary containing at least `approvals__grants`. + tool_call: The `FunctionCall` to be evaluated. + user_id: The ID of the user initiating the call. + session_id: The ID of the current session. + + Returns: + An `ApprovalRequest` object if there are unmet challenges, otherwise `None`. + + Raises: + ApprovalDenied: If an existing grant explicitly denies one of the required + action/resource pairs. 
+ """ + + policies = ApprovalPolicyRegistry.get_tool_policies(tool_call.name) + grants = cls._get_existing_grants(state) + + user_actor = ApprovalActor(id=user_id, type="user") + agent_actor = ApprovalActor( + id=session_id, type="agent", on_behalf_of=user_actor + ) + function_call_actor = ApprovalActor( + id=f"{TOOL_NAMESPACE}:{tool_call.name}:{tool_call.id}", + type="tool", + on_behalf_of=agent_actor, + ) + + approved_policy_pairs = [] + denied_policy_pairs = [] + challenges = [] + + for policy in policies: + unmet_policy_pairs = [] + for policy_action, policy_resource in policy.get_action_resources( + tool_call.args + ): + # Check against function call grantee, then check against the agent and whether delegation is possible + if effect := cls._check_action_on_resource_against_grants( + policy_action, policy_resource, grants, function_call_actor + ): + if effect == ApprovalEffect.deny: + denied_policy_pairs.append((policy_action, policy_resource)) + elif effect == ApprovalEffect.allow: + approved_policy_pairs.append((policy_action, policy_resource)) + continue + unmet_policy_pairs.append((policy_action, policy_resource)) + + grouped_challenges = {} + + for action, resource in unmet_policy_pairs: + if resource not in grouped_challenges: + grouped_challenges[resource] = [action] + else: + grouped_challenges[resource].append(action) + for resource, actions in grouped_challenges.items(): + challenges.append( + ApprovalChallenge( + actions=actions, + resources=[resource], + ) + ) + + if denied_policy_pairs: + raise ApprovalDenied( + denied_challenges=[ + ApprovalChallenge( + grantee=function_call_actor, + actions=[action], + resources=[resource], + ) + for action, resource in denied_policy_pairs + ], + ) + if challenges: + return ApprovalRequest( + function_call=tool_call, + challenges=challenges, + grantee=function_call_actor, + ) + + @classmethod + def _get_existing_grants(cls, state: State) -> list[ApprovalGrant]: + """Retrieves and validates existing approval grants from the state. + + Args: + state: The session state, expected to contain an `approvals__grants` key + with a list of grant dictionaries. + + Returns: + A list of `ApprovalGrant` objects. + """ + return [ + ApprovalGrant.model_validate(grant) + for grant in state.get("approvals__grants", []) + ] + + @staticmethod + def _resource_met(policy_resource: str, grant_resource: str) -> bool: + """Checks if a policy resource string matches a grant resource string, supporting wildcards. + + The `grant_resource` can contain wildcards (`*`) which match any sequence of characters + within a segment, or an entire segment if the segment itself is `*`. + Resource strings are colon-separated (e.g., "namespace:type:identifier"). + + Examples: + - `_resource_met("tool:files:read", "tool:files:*")` -> `True` + - `_resource_met("tool:files:read", "tool:*:read")` -> `True` + - `_resource_met("tool:files:read", "*")` -> `True` + - `_resource_met("foo:bar", "foo:baz")` -> `False` + + Args: + policy_resource: The specific resource being accessed (e.g., "tool:files:/data/my_file.txt"). + grant_resource: The resource pattern from a grant (e.g., "tool:files:*", "*"). + + Returns: + `True` if the `policy_resource` matches the `grant_resource` pattern, `False` otherwise. 
+ """ + # Full wildcard matches anything + if grant_resource == "*": + return True + + # Split resources into segments + policy_segments = policy_resource.split(":") + grant_segments = grant_resource.split(":") + + # If we have different segment counts and not covered by the special case above + if len(policy_segments) != len(grant_segments): + return False + + # Compare each segment + for policy_segment, grant_segment in zip(policy_segments, grant_segments): + if grant_segment == "*": + continue + + if "*" in grant_segment: + prefix, suffix = grant_segment.split("*", maxsplit=1) + if not policy_segment.startswith(prefix): + return False + if not policy_segment.endswith(suffix): + return False + continue + + if policy_segment != grant_segment: + return False + + return True + + @classmethod + def _check_actor(cls, actor: ApprovalActor, grantee: ApprovalActor) -> bool: + """Recursively checks if an `actor` matches a `grantee` definition, including `on_behalf_of`. + + This method verifies that the `id` and `type` of the `actor` match the `grantee`. + If `grantee.on_behalf_of` is set, it recursively checks that `actor.on_behalf_of` + also matches. + + Args: + actor: The `ApprovalActor` requesting access. + grantee: The `ApprovalActor` specified in a grant. + + Returns: + `True` if the actor matches the grantee definition, `False` otherwise. + """ + # Check if the grantee IDs match + if not cls._check_actor_id(actor_id=actor.id, grantee_id=grantee.id): + return False + + # Check if the grantee types match + if actor.type != grantee.type: + return False + + # If on_behalf_of is specified in the grant, check that too + if grantee.on_behalf_of is not None: + if actor.on_behalf_of is None: + return False + + return cls._check_actor(actor.on_behalf_of, grantee.on_behalf_of) + + return True + + @classmethod + def _check_actor_id(cls, actor_id: str, grantee_id: str) -> bool: + """Checks if an actor ID matches a grantee ID, supporting wildcards. + + Similar to `_resource_met`, but for actor IDs. The `grantee_id` can contain + wildcards (`*`) to match parts of or the entire `actor_id`. + + Args: + actor_id: The ID of the actor requesting access. + grantee_id: The ID pattern from a grant. + + Returns: + `True` if the `actor_id` matches the `grantee_id` pattern, `False` otherwise. + """ + if actor_id == grantee_id: + return True + + if grantee_id == "*": + return True + + actor_segments = actor_id.split(":") + grantee_segments = grantee_id.split(":") + + if len(actor_segments) != len(grantee_segments): + return False + + # Compare each segment + for actor_segment, grantee_segment in zip(actor_segments, grantee_segments): + if grantee_segment == "*": + continue + + if actor_segment == grantee_segment: + continue + + if "*" in grantee_segment: + prefix, suffix = grantee_segment.split("*", maxsplit=1) + if not actor_segment.startswith(prefix): + return False + if not actor_segment.endswith(suffix): + return False + continue + + if actor_segment != grantee_segment: + return False + + return True + + @classmethod + def _action_granted_for_resource( + cls, + grant: ApprovalGrant, + policy_action: ApprovalAction, + policy_resource: str, + actor: ApprovalActor, + ) -> Optional[Literal[ApprovalEffect.allow, ApprovalEffect.deny]]: + """Checks if a specific action on a resource by an actor is permitted or denied by a single grant. + + This method verifies: + 1. The `policy_action` is listed in `grant.actions`. + 2. The `actor` matches `grant.grantee` (using `_check_actor`). + 3. 
The `policy_resource` matches one of the `grant.resources` (using `_resource_met`). + + If all conditions are met, it returns the `grant.effect` (allow or deny). + + Args: + grant: The `ApprovalGrant` to check against. + policy_action: The action being attempted (e.g., "tool:files:read"). + policy_resource: The resource being accessed (e.g., "tool:files:/data/my_doc.txt"). + actor: The `ApprovalActor` attempting the action. + + Returns: + `ApprovalEffect.allow` or `ApprovalEffect.deny` if the grant applies and matches, + otherwise `None`. + """ + # Check if the action is in the grant's actions + if policy_action not in grant.actions: + return None + + # Check if the grantee matches + if not cls._check_actor(actor, grant.grantee): + return None + + # Check if any of the resources match + for grant_resource in grant.resources: + if cls._resource_met(policy_resource, grant_resource): + return grant.effect + + return None + + @classmethod + def _check_action_on_resource_against_grants( + cls, + action: ApprovalAction, + resource: str, + grants: list[ApprovalGrant], + actor: ApprovalActor, + ) -> Optional[Literal[ApprovalEffect.allow, ApprovalEffect.deny]]: + """Evaluates an action on a resource against a list of grants to determine its effective status. + + Deny grants are prioritized. If any deny grant matches the action, resource, and actor, + `ApprovalEffect.deny` is returned immediately. + If no deny grants match, allow grants are checked. If an allow grant matches, + `ApprovalEffect.allow` is returned. + If no grants match, `None` is returned, indicating the action is not explicitly + allowed or denied by the provided grants. + + Args: + action: The `ApprovalAction` being attempted. + resource: The resource string being accessed. + grants: A list of `ApprovalGrant` objects to check against. + actor: The `ApprovalActor` attempting the action. + + Returns: + `ApprovalEffect.allow` if an allow grant matches and no deny grant matches. + `ApprovalEffect.deny` if a deny grant matches. + `None` if no grants explicitly cover the action/resource/actor combination. + """ + # Prioritize deny grants, then check allow grants + allow_grants = [ + grant for grant in grants if grant.effect == ApprovalEffect.allow + ] + deny_grants = [ + grant for grant in grants if grant.effect == ApprovalEffect.deny + ] + + # Check deny grants first (they take precedence) + for grant in deny_grants: + if effect := cls._action_granted_for_resource( + grant, action, resource, actor + ): + return effect + + # Then check allow grants + for grant in allow_grants: + if effect := cls._action_granted_for_resource( + grant, action, resource, actor + ): + return effect + + return None diff --git a/src/google/adk/approval/approval_policy.py b/src/google/adk/approval/approval_policy.py new file mode 100644 index 0000000000..d3799c9dc1 --- /dev/null +++ b/src/google/adk/approval/approval_policy.py @@ -0,0 +1,255 @@ + +"""Defines the structure for approval policies and a registry to manage them. + +This module provides: +- `ApprovalPolicy`: An abstract base class for defining policies that require approval. +- `FunctionToolPolicy`: A concrete policy implementation specifically for tools (functions) + that defines actions and how to map tool arguments to resource strings. +- `ApprovalPolicyRegistry`: A singleton registry to store and retrieve tool policies. +- Decorators and helper functions (`@tool_policy`, `register_policy_for_tool`, + `resource_parameters`, `resource_parameter_map`) for easily defining and registering policies. 
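+
+For example, a policy can be attached to a tool with the `@tool_policy` decorator
+(an illustrative sketch; `read_file` is a hypothetical tool):
+
+    @tool_policy(
+        actions=["tool:read_file"],
+        resources=resource_parameters("tool:files", ["path"]),
+    )
+    def read_file(path: str) -> str:
+        ...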
+ +Policies determine what actions on what resources require approval before a tool can be executed. +""" +from typing import Any, Callable, Optional + +from pydantic import BaseModel + +from google.adk.approval.approval_grant import ApprovalAction +from google.adk.tools import BaseTool + + +class ApprovalPolicy(BaseModel): + """Abstract base model for an approval policy. + + An approval policy defines a set of actions that, when applied to resources derived + from tool arguments, require an explicit approval grant. + """ + policy_name: str | None = None + """An optional unique name for the policy, often linking it to a specific tool or agent feature.""" + actions: list[ApprovalAction] + """The list of actions (e.g., 'tool:files:read', 'agent:delegate') that this policy governs.""" + + def get_resources(self, args: dict[str, Any]) -> list[str]: + """Abstract method to derive resource strings from tool arguments. + + Subclasses must implement this to define how tool invocation arguments + map to specific resource identifiers that the policy's actions apply to. + + Args: + args: A dictionary of arguments passed to a tool call. + + Returns: + A list of resource strings. + + Raises: + NotImplementedError: If the subclass does not override this method. + """ + raise NotImplementedError("get_resources not implemented") + + +TOOL_NAMESPACE = "tool" +AGENT_NAMESPACE = "agent" +AGENTS_NAMESPACE = f"{AGENT_NAMESPACE}:agents" + + +class FunctionToolPolicy(ApprovalPolicy): + """A policy specifically designed for function tools. + + It links tool names to actions and provides a mechanism (`resource_mappers`) + to extract resource strings from the arguments passed to the tool. + + Attributes: + resource_mappers: A callable that takes tool arguments (dict) and returns a list of resource strings. + """ + + resource_mappers: Callable[[dict[str, Any]], list[str]] + """A function that maps tool arguments to a list of resource strings. + For example, for a file reading tool, this might extract the file path. + """ + + @property + def tool_name(self) -> Optional[str]: + """Extracts the tool name if the policy_name follows the 'tool:' format.""" + if self.policy_name.startswith(f"{TOOL_NAMESPACE}:"): + return self.policy_name[len(f"{TOOL_NAMESPACE}:") :] + else: + return None + + @staticmethod + def format_name(tool_name: str) -> str: + """Formats a policy name for a tool, ensuring it's prefixed with the tool namespace.""" + return f"{TOOL_NAMESPACE}:{tool_name}" + + def get_resources(self, args: dict[str, Any]) -> list[str]: + """Derives resource strings from tool arguments using the `resource_mappers` function.""" + return self.resource_mappers(args) + + def get_action_resources( + self, args: dict[str, Any] + ) -> list[tuple[ApprovalAction, str]]: + """Generates all action-resource pairs covered by this policy for given tool arguments. + + Args: + args: The arguments passed to the tool call. + + Returns: + A list of tuples, where each tuple is (action_string, resource_string). + """ + return [ + (action, resource) + for action in self.actions + for resource in self.get_resources(args) + ] + + +class ApprovalPolicyRegistry(object): + """A singleton registry for managing `FunctionToolPolicy` instances. + + This registry stores policies associated with tool names, allowing the system + to look up relevant policies when a tool is about to be executed. 
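+
+  For example (an illustrative sketch; `read_file` is a hypothetical tool name):
+
+      policy = FunctionToolPolicy(
+          policy_name=FunctionToolPolicy.format_name("read_file"),
+          actions=["tool:read_file"],
+          resource_mappers=lambda args: [f"tool:files:{args['path']}"],
+      )
+      ApprovalPolicyRegistry.register_tool_policy(policy)
+      ApprovalPolicyRegistry.get_tool_policies("read_file")  # -> [policy]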
+ """ + + tool_policies: list[FunctionToolPolicy] = [] + """Static list holding all registered `FunctionToolPolicy` instances.""" + + @classmethod + def register_tool_policy(cls, policy: FunctionToolPolicy): + """Registers a `FunctionToolPolicy`. + + Ensures that the policy has a name and adds it to the global list if not already present. + + Args: + policy: The `FunctionToolPolicy` instance to register. + + Raises: + ValueError: If `policy.policy_name` is None. + """ + if policy.policy_name is None: + raise ValueError("Policy name cannot be None") + if policy not in cls.tool_policies: + cls.tool_policies.append(policy) + + @classmethod + def get_tool_policies(cls, tool_name: str) -> list[FunctionToolPolicy]: + """Retrieves all registered policies associated with a given tool name. + + Args: + tool_name: The name of the tool. + + Returns: + A list of `FunctionToolPolicy` instances whose `tool_name` matches. + """ + return [ + policy for policy in cls.tool_policies if policy.tool_name == tool_name + ] + + +register_tool_policy = ApprovalPolicyRegistry.register_tool_policy +"""Alias for `ApprovalPolicyRegistry.register_tool_policy` for convenient access.""" + + +def register_policy_for_tool( + tool: BaseTool | Callable, + policy: FunctionToolPolicy, +): + """Registers a given policy for a specific tool. + + This helper function constructs the correct policy name based on the tool's name + (whether it's a `BaseTool` instance or a callable) and then registers the policy. + + Args: + tool: The tool (either a `BaseTool` subclass or a callable) to associate the policy with. + policy: The `FunctionToolPolicy` to register. The `policy_name` attribute of this + object will be overridden. + """ + if isinstance(tool, BaseTool): + tool_name = tool.name + else: + tool_name = tool.__name__ + policy = FunctionToolPolicy( + policy_name=FunctionToolPolicy.format_name(tool_name=tool_name), + actions=policy.actions, + resource_mappers=policy.resource_mappers, + ) + register_tool_policy(policy) + + +def tool_policy( + actions: list[ApprovalAction], + resources: Callable[[dict[str, Any]], list[str]], +) -> Callable[[BaseTool | Callable], BaseTool | Callable]: + """Decorator to associate an approval policy with a tool function or `BaseTool` class. + + This decorator simplifies the creation and registration of `FunctionToolPolicy` objects. + + Args: + actions: A list of `ApprovalAction` strings that the policy governs. + resources: A callable that takes the tool's arguments (a dict) and returns a list + of resource strings to which the actions apply. + + Returns: + A decorator function that, when applied to a tool, registers a policy for it. + """ + def register(tool): + if isinstance(tool, BaseTool): + tool_name = tool.name + else: + tool_name = tool.__name__ + policy = FunctionToolPolicy( + policy_name=FunctionToolPolicy.format_name(tool_name=tool_name), + actions=actions, + resource_mappers=resources, + ) + register_tool_policy(policy) + return tool + + return register + + +def resource_parameters(namespace: str, parameters: list[str]) -> Callable[[dict[str, Any]], list[str]]: + """Creates a resource mapper function that extracts resource strings from specified tool arguments. + + This is a convenience function for a common pattern where resource identifiers are + directly derived from the values of certain tool parameters, prefixed with a namespace. 
+
+  Example:
+    `resource_parameters("tool:files", ["path"])` would create a mapper that, for a tool call
+    `my_tool(path="/data/f.txt")`, returns `["tool:files:/data/f.txt"]`.
+
+  Args:
+    namespace: The namespace string to prefix to each parameter value (e.g., "tool:my_tool_namespace").
+    parameters: A list of parameter names from the tool's arguments whose values will be used
+      to construct resource strings.
+
+  Returns:
+    A callable suitable for use as the `resources` argument in `@tool_policy` or
+    the `resource_mappers` attribute of `FunctionToolPolicy`.
+  """
+  mapping = {parameter: (namespace + ":{}").format for parameter in parameters}
+  return resource_parameter_map(**mapping)
+
+
+def resource_parameter_map(**mapping: Callable[[Any], str]) -> Callable[[dict[str, Any]], list[str]]:
+  """Creates a resource mapper function based on a direct mapping of argument names to formatters.
+
+  This provides a flexible way to construct resource strings by applying specific formatting
+  functions to the values of named tool arguments.
+
+  Example:
+    `resource_parameter_map(filePath=lambda x: f"files:{x}", dirPath=lambda x: f"dirs:{x}")`
+    creates a mapper that, for a tool called as `my_tool(filePath="/a.txt", otherArg=1)`,
+    returns `["files:/a.txt"]` (since `dirPath` was not provided, it is skipped).
+
+  Args:
+    **mapping: Keyword arguments where each key is a tool argument name, and the value is a
+      callable that takes the argument's value and returns a formatted resource string part.
+
+  Returns:
+    A callable suitable for use as the `resources` argument in `@tool_policy` or
+    the `resource_mappers` attribute of `FunctionToolPolicy`.
+  """
+  def resources_map(args):
+    return [v(args[k]) for k, v in mapping.items() if k in args]
+
+  return resources_map
diff --git a/src/google/adk/approval/approval_request.py b/src/google/adk/approval/approval_request.py
new file mode 100644
index 0000000000..31a4151037
--- /dev/null
+++ b/src/google/adk/approval/approval_request.py
@@ -0,0 +1,166 @@
+
+"""Defines data structures for representing approval requests, challenges, and responses.
+
+This module includes:
+- `ApprovalChallenge`: Represents a specific set of actions on resources that require approval.
+- `ApprovalStatus`: Represents the overall status of an approval (currently defined but unused internally).
+- `ApprovalDenied`: An exception raised when an approval is explicitly denied.
+- `ApprovalRequest`: Encapsulates a function call that is pending approval, along with the
+  challenges that need to be met and the grantee.
+- `ApprovalResponse`: Represents the user's or system's response to an approval request,
+  typically containing a list of grants.
+- `FunctionCallStatus`: Tracks the status (suspended, resumed, cancelled) of a function call
+  that is subject to the approval workflow.
+"""
+from __future__ import annotations
+from typing import Any, Literal
+
+from google.genai.types import FunctionCall
+from pydantic import BaseModel
+
+from google.adk.approval.approval_grant import ApprovalAction, ApprovalActor, ApprovalEffect, ApprovalGrant, ApprovalResource
+
+
+class ApprovalChallenge(BaseModel):
+  """Represents a specific challenge within an approval request.
+
+  A challenge details a set of actions on a list of resources that require
+  explicit approval before they can be performed.
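+
+  For example (illustrative), a challenge asking for read access to a single file:
+
+      ApprovalChallenge(
+          actions=["tool:read_file"],
+          resources=["tool:files:/data/report.txt"],
+      )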
+ """ + + actions: list[ApprovalAction] + """The list of actions (e.g., 'tool:files:read') that are being challenged.""" + resources: list[ApprovalResource] + """The list of resources (e.g., 'tool:files:/data/my_file.txt') to which the actions apply.""" + + +class ApprovalStatus(BaseModel): + """Represents the overall status of an approval attempt (e.g., allowed, denied, challenged). + + Note: This class is defined but not actively used by the core approval handler in this version. + It might be used for more complex approval UIs or logging. + + Attributes: + effect: The overall effect of the approval status. + challenges: A list of remaining or satisfied challenges. + """ + effect: ApprovalEffect + challenges: list[ApprovalChallenge] + + +class ApprovalDenied(ValueError): + """Exception raised when an action is explicitly denied by an approval grant. + + Attributes: + denied_challenges: A list of `ApprovalChallenge` objects that were denied. + """ + + def __init__(self, denied_challenges: list[ApprovalChallenge]): + """Initializes the ApprovalDenied exception. + + Args: + denied_challenges: The list of challenges that resulted in the denial. + """ + super().__init__() + self.denied_challenges = denied_challenges + + +class ApprovalRequest(BaseModel): + """Represents a request for approval for a specific function call. + + This is typically generated by the `ApprovalHandler` when a tool call requires + permissions that are not yet granted. + + Attributes: + function_call: The `FunctionCall` that triggered this approval request. + challenges: A list of `ApprovalChallenge` objects detailing what needs to be approved. + grantee: The `ApprovalActor` (typically a tool) for whom the approval is requested. + """ + function_call: FunctionCall + challenges: list[ApprovalChallenge] + grantee: ApprovalActor + """The actor (typically a tool identified by its function call ID) for whom the approval is being sought.""" + + +class ApprovalResponse(BaseModel): + """Represents a response to an `ApprovalRequest`, usually containing new grants. + + This is typically sent by the user or an automated system to grant or deny + the permissions requested in an `ApprovalRequest`. + + Attributes: + grants: A list of `ApprovalGrant` objects resulting from the approval decision. + """ + grants: list[ApprovalGrant] + + +class FunctionCallStatus(BaseModel): + """Tracks the status and sequence of a function call within the approval workflow. + + When a function call requires approval, it is marked as "suspended". Once approval + is granted (or if it's determined no approval is needed), it's "resumed". + If denied, it's "cancelled". The sequence helps in ordering multiple status updates + for the same function call if needed, though typically only the latest status is relevant. + + Attributes: + function_call: The `FunctionCall` whose status is being tracked. + status: The current approval-related status of the function call. + sequence: A monotonically increasing number to order status updates. + """ + function_call: FunctionCall + status: Literal["suspended", "resumed", "cancelled"] + sequence: int + + @classmethod + def get_next_sequence(cls, function_call_statuses: list[dict[str, Any]]) -> int: + """Calculates the next sequence number based on existing statuses. + + Args: + function_call_statuses: A list of dictionaries, where each dictionary + represents a serialized `FunctionCallStatus`. + + Returns: + The next available sequence number (max existing sequence + 1). 
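+
+    Example (illustrative):
+
+        FunctionCallStatus.get_next_sequence([])  # -> 1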
+ """ + return ( + max( + [ + cls.model_validate(fcs_dict).sequence + for fcs_dict in function_call_statuses + ], + default=0, + ) + + 1 + ) + + @classmethod + def update_status( + cls, + function_call_statuses: list[dict[str, Any]], + function_call: FunctionCall, + status: Literal["suspended", "resumed", "cancelled"], + ) -> list[dict[str, Any]]: + """Updates the status list for a given function call. + + Removes any existing status for the specified `function_call` and adds a new + `FunctionCallStatus` with the given `status` and the next sequence number. + + Args: + function_call_statuses: The current list of serialized `FunctionCallStatus` objects. + function_call: The `FunctionCall` to update. + status: The new status for the function call. + + Returns: + A new list of serialized `FunctionCallStatus` objects with the update applied. + """ + return [ + fcs_dict + for fcs_dict in function_call_statuses + if cls.model_validate(fcs_dict).function_call != function_call + ] + [ + cls( + function_call=function_call, + status=status, + sequence=cls.get_next_sequence(function_call_statuses), + ).model_dump(mode="json") + ] diff --git a/src/google/adk/approval/approval_request_processor.py b/src/google/adk/approval/approval_request_processor.py new file mode 100644 index 0000000000..799e258dc3 --- /dev/null +++ b/src/google/adk/approval/approval_request_processor.py @@ -0,0 +1,307 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides an LLM request preprocessor for handling the approval workflow. + +This module defines `_ApprovalLlmRequestProcessor`, which integrates into the LLM +request lifecycle to: +1. Check for incoming approval responses from the user/client. +2. Update the session state with new grants based on these responses. +3. Identify and prepare previously suspended function calls (due to pending approvals) + that can now be resumed with the new grants. +4. Generate necessary events to reflect grant updates and to trigger the execution + of resumed function calls. + +It also includes helper functions for extracting approval-related information from events. 
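+
+A client typically answers an `adk_request_approval` function call with a function
+response whose payload parses as an `ApprovalResponse` (an illustrative sketch; the
+IDs, actions, and resources are made up):
+
+    {
+        "grants": [
+            {
+                "effect": "allow",
+                "actions": ["tool:read_file"],
+                "resources": ["tool:files:/data/report.txt"],
+                "grantee": {"id": "*", "type": "tool"},
+                "grantor": {"id": "user-123", "type": "user"},
+            }
+        ]
+    }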
+"""
+
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any, AsyncGenerator, Generator, Optional, TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+
+from .approval_grant import ApprovalGrant
+from .approval_handler import ApprovalHandler
+from .approval_request import ApprovalRequest, ApprovalResponse, FunctionCallStatus
+from ..agents.invocation_context import InvocationContext
+from ..events import EventActions
+from ..events.event import Event
+from ..flows.llm_flows import functions
+from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
+from ..models.llm_request import LlmRequest
+from ..sessions import State
+
+if TYPE_CHECKING:
+  from ..agents.llm_agent import LlmAgent
+
+# The name of the function call for requesting approval
+REQUEST_APPROVAL_FUNCTION_CALL_NAME = "adk_request_approval"
+
+
+class _ApprovalLlmRequestProcessor(BaseLlmRequestProcessor):
+  """Processes incoming approval responses and manages suspended function calls.
+
+  This preprocessor runs before the main LLM content generation. It inspects the latest
+  user event for any `ApprovalResponse` function responses. If found, it updates the
+  session's grants. Then, it checks if any function calls were previously suspended
+  pending these approvals. If such calls exist and can now proceed (or are still
+  partially pending), it prepares them for re-evaluation or execution.
+  """
+
+  @override
+  async def run_async(
+      self, invocation_context: InvocationContext, llm_request: LlmRequest
+  ) -> AsyncGenerator[Event, None]:
+    """Executes the approval processing logic.
+
+    Checks for approval responses in the last user event, updates grants, and
+    resumes or further processes suspended function calls.
+
+    Args:
+      invocation_context: The current invocation context, providing access to the
+        session, agent, and other invocation details.
+      llm_request: The current LLM request object (not directly modified by this
+        processor but part of the interface).
+
+    Yields:
+      `Event` objects representing state changes (grant updates) or new function
+ """ + from ..agents.llm_agent import LlmAgent + + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + return + + if not (events := invocation_context.session.events): + return + + state = invocation_context.session.state + + if not (approval_responses := self._get_approval_responses(events)): + return + + if grant_update_actions := self._process_approvals( + approval_responses, state + ): + yield Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + actions=grant_update_actions, + branch=invocation_context.branch, + ) + state.update(grant_update_actions.state_delta) + else: + invocation_context.end_invocation = True + return + + if function_calls_content := self._get_suspended_function_calls_content( + state + ): + # Reset suspended_function_calls in state (need to ensure a state_delta event is sent later to track this) + state["approvals__suspended_function_calls"] = [] + + function_calls_event = Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=function_calls_content, + ) + + if function_response_event := await functions.handle_function_calls_async( + invocation_context, + function_calls_event, + {tool.name: tool for tool in await agent.canonical_tools()}, + ): + if "approvals" not in function_response_event.actions.state_delta: + function_response_event.actions.state_delta["approvals"] = {} + + # reset suspended_function_calls if no state_deltas created - all were passed through grant testing logic again + # function_response_event.actions.state_delta["approvals__suspended_function_calls"] = function_response_event.actions.state_delta.get("approvals__suspended_function_calls") + + yield function_response_event + state.update(function_response_event.actions.state_delta) + + def _get_approval_responses( + self, events: list[Event] + ) -> list[ApprovalResponse]: + """Extracts `ApprovalResponse` objects from the last user event. + + Searches the most recent event authored by the "user" for function responses + named `REQUEST_APPROVAL_FUNCTION_CALL_NAME` and parses them into + `ApprovalResponse` model instances. + + Args: + events: A list of all events in the current session, ordered chronologically. + + Returns: + A list of `ApprovalResponse` objects found, or an empty list if none are present. + """ + if not (user_event := get_last_user_response(events)): + return [] + + return [ + ApprovalResponse.model_validate(function_call_response.response) + for function_call_response in user_event.get_function_responses() + if function_call_response.name == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ] + + def _process_approvals( + self, approval_responses: list[ApprovalResponse], state: dict[str, Any] + ) -> Optional[EventActions]: + """Processes approval responses to update grants in the state. + + Compares grants from `approval_responses` with existing grants in `state` + and generates `EventActions` to update `state["approvals__grants"]` if + new grants are found. + + Args: + approval_responses: A list of `ApprovalResponse` objects from the user. + state: The current session state dictionary. + + Returns: + An `EventActions` object with a `state_delta` for grant updates if new + grants were processed, otherwise `None`. 
+    """
+    initial_grants = self._get_grants(state)
+    if extra_grants := ApprovalHandler.parse_and_store_approval_responses(
+        initial_grants=initial_grants,
+        approval_responses=approval_responses,
+    ):
+      return EventActions(
+          state_delta={
+              "approvals__grants": [
+                  grant.model_dump(mode="json")
+                  for grant in initial_grants + extra_grants
+              ]
+          },
+      )
+
+  def _get_suspended_function_calls_content(
+      self, state: dict[str, Any]
+  ) -> Optional[types.Content]:
+    """Creates a `types.Content` object containing currently suspended function calls.
+
+    Retrieves function calls from `state["approvals__suspended_function_calls"]` that
+    have a status of "suspended" and packages them into a `types.Content` object,
+    which can be used to re-trigger their processing.
+
+    Args:
+      state: The current session state dictionary.
+
+    Returns:
+      A `types.Content` object with parts for each suspended function call, or
+      `None` if no calls are currently suspended.
+    """
+    suspended_function_calls = self._get_suspended_function_calls(state)
+    if not suspended_function_calls:
+      return None
+    return types.Content(
+        role="model",
+        parts=[
+            types.Part(function_call=function_call)
+            for function_call in suspended_function_calls
+        ],
+    )
+
+  def _get_grants(self, state: dict[str, Any]) -> list[ApprovalGrant]:
+    """Retrieves and validates `ApprovalGrant` objects from the state.
+
+    Args:
+      state: The current session state dictionary.
+
+    Returns:
+      A list of `ApprovalGrant` model instances.
+    """
+    return [
+        ApprovalGrant.model_validate(grant_dict)
+        for grant_dict in state.get("approvals__grants", [])
+    ]
+
+  def _get_suspended_function_calls(self, state: dict[str, Any]) -> list[types.FunctionCall]:
+    """Extracts unique, currently suspended function calls from the state.
+
+    Filters `state["approvals__suspended_function_calls"]` for entries with status
+    "suspended" and returns a list of unique `types.FunctionCall` objects.
+    Uniqueness is based on the function call ID.
+
+    Args:
+      state: The current session state dictionary.
+
+    Returns:
+      A list of unique `types.FunctionCall` objects that are suspended.
+    """
+    function_calls = [
+        FunctionCallStatus.model_validate(suspended_function_call).function_call
+        for suspended_function_call in state.get(
+            "approvals__suspended_function_calls", []
+        )
+        if FunctionCallStatus.model_validate(suspended_function_call).status == "suspended"
+    ]
+    function_calls_dict = {fc.id: fc for fc in function_calls}
+    return list(function_calls_dict.values())
+
+
+# Create the preprocessor instance
+request_processor = _ApprovalLlmRequestProcessor()
+
+
+def get_last_user_response(events: list[Event]) -> Optional[Event]:
+  """Returns the most recent user event, if it contains function responses.
+
+  Looks backwards through the event list for the latest user event; that event is
+  returned only if it contains at least one function response.
+
+  Args:
+    events: The list of all session events.
+
+  Returns:
+    The last relevant user `Event`, or `None` if not found.
+  """
+  user_events = list(get_user_responses(events, max_events=1))
+  if len(user_events) == 1:
+    return user_events[0]
+
+
+def get_user_responses(events: list[Event], max_events: Optional[int] = None) -> Generator[Event, None, None]:
+  """Yields user events containing function responses, most recent first.
+
+  Iterates backwards through the event list, yielding user events that contain
+  function responses; iteration stops at the first user event without any.
+
+  Args:
+    events: The list of all session events.
+      max_events: Optional limit on the number of events to yield.
+
+    Yields:
+      `Event` objects that meet the criteria.
+    """
+    events_emitted = 0
+    for k in range(len(events) - 1, -1, -1):
+      event = events[k]
+      # Look for first event authored by user
+      if not event.author or event.author != "user":
+        continue
+
+      responses = event.get_function_responses()
+      if not responses:
+        return
+
+      yield event
+      events_emitted += 1
+      if max_events is not None and events_emitted >= max_events:
+        break
diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py
index 994a7900b9..a8becc6cf3 100644
--- a/src/google/adk/events/event_actions.py
+++ b/src/google/adk/events/event_actions.py
@@ -21,6 +21,7 @@
 from pydantic import ConfigDict
 from pydantic import Field
 
+from ..approval.approval_request import ApprovalRequest
 from ..auth.auth_tool import AuthConfig
 
 
@@ -64,3 +65,16 @@ class EventActions(BaseModel):
     identify the function call.
     - Values: The requested auth config.
   """
+
+  requested_approvals: list[ApprovalRequest] = Field(default_factory=list)
+  """A list of `ApprovalRequest` objects generated by tools during their execution.
+
+  When a tool determines that it requires specific approvals to proceed (e.g., due to
+  its internal logic or policies it's aware of, though typically policies are checked
+  by `ApprovalHandler` before tool execution), it can populate this list.
+  The `ApprovalHandler` or a flow can then use these requests to initiate an
+  approval workflow with the user.
+  This is primarily used when a tool call is made and the `ApprovalHandler` determines
+  that one or more challenges exist, leading to an `ApprovalRequest` being added here
+  via `ToolContext.request_approval()`.
+  """
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index 788f1aa4f1..1a196a9d92 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -456,6 +456,14 @@ async def _postprocess_handle_function_calls_async(
     if function_response_event := await functions.handle_function_calls_async(
         invocation_context, function_call_event, llm_request.tools_dict
     ):
+      # If the function responses triggered any approval requests (e.g., a tool call was suspended),
+      # generate an event to send these approval requests to the client/user.
+      approval_event = functions.generate_approval_event(
+ invocation_context, function_response_event + ) + if approval_event: + yield approval_event + auth_event = functions.generate_auth_event( invocation_context, function_response_event ) diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index b37d8aff3e..07913d1afe 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -26,7 +26,7 @@ from ...events.event import Event from ...models.llm_request import LlmRequest from ._base_llm_processor import BaseLlmRequestProcessor -from .functions import remove_client_function_call_id +from .functions import REQUEST_APPROVAL_FUNCTION_CALL_NAME, remove_client_function_call_id from .functions import REQUEST_EUC_FUNCTION_CALL_NAME @@ -172,7 +172,7 @@ def _rearrange_events_for_latest_function_response( function_responses = event.get_function_responses() if ( function_responses - and function_responses[0].id in function_responses_ids + and any([function_response.id in function_responses_ids for function_response in function_responses]) ): function_response_events.append(event) function_response_events.append(events[-1]) @@ -218,6 +218,10 @@ def _get_contents( if _is_auth_event(event): # skip auth event continue + if _is_approval_event(event): + # Skip approval events as they are part of the approval flow + # and not direct conversational content for the LLM history. + continue filtered_events.append( _convert_foreign_event(event) if _is_other_agent_reply(agent_name, event) @@ -394,3 +398,31 @@ def _is_auth_event(event: Event) -> bool: ): return True return False + + +def _is_approval_event(event: Event) -> bool: + """Checks if an event is related to the approval workflow. + + An event is considered an approval event if it contains a function call or + function response part where the function name is `REQUEST_APPROVAL_FUNCTION_CALL_NAME`. + + Args: + event: The `Event` to check. + + Returns: + `True` if the event is an approval-related event, `False` otherwise. + """ + if not event.content.parts: + return False + for part in event.content.parts: + if ( + part.function_call + and part.function_call.name == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ): + return True + if ( + part.function_response + and part.function_response.name == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ): + return True + return False diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 53ddb3564d..9416535bfa 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -29,9 +29,12 @@ from ...agents.active_streaming_tool import ActiveStreamingTool from ...agents.invocation_context import InvocationContext +from ...approval.approval_handler import ApprovalHandler +from ...approval.approval_request_processor import REQUEST_APPROVAL_FUNCTION_CALL_NAME from ...auth.auth_tool import AuthToolArguments from ...events.event import Event from ...events.event_actions import EventActions +from ...sessions import State from ...telemetry import trace_tool_call from ...telemetry import trace_tool_response from ...telemetry import tracer @@ -123,6 +126,55 @@ def generate_auth_event( ) +def generate_approval_event( + invocation_context: InvocationContext, + function_response_event: Event, +) -> Optional[Event]: + """Generates an event to forward approval requests to the client. 
+ + If the `function_response_event` (typically resulting from tool execution attempts + or `ApprovalHandler` checks) contains `ApprovalRequest` objects in its + `actions.requested_approvals` field, this function constructs a new `Event`. + This new event will contain function call parts, each named + `REQUEST_APPROVAL_FUNCTION_CALL_NAME`, with the `ApprovalRequest` details as arguments. + This event is intended to be sent to the client to solicit user approval. + + Args: + invocation_context: The current invocation context, used for authoring the new event. + function_response_event: The event (usually from `handle_function_calls_async` or + as a result of `ApprovalHandler.get_approval_request`) + that might contain `requested_approvals` in its `actions`. + + Returns: + A new `Event` formatted to request approvals if `requested_approvals` is present + and non-empty in the input event's actions; otherwise, `None`. + """ + if not function_response_event.actions.requested_approvals: + return None + + parts = [] + long_running_tool_ids = set() + for approval_request in function_response_event.actions.requested_approvals: + + request_approval_function_call = types.FunctionCall( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + args=approval_request.model_dump(exclude_none=True), + ) + request_approval_function_call.id = generate_client_function_call_id() + long_running_tool_ids.add(request_approval_function_call.id) + parts.append(types.Part(function_call=request_approval_function_call)) + + return Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=types.Content( + parts=parts, role=function_response_event.content.role + ), + long_running_tool_ids=long_running_tool_ids, + ) + + async def handle_function_calls_async( invocation_context: InvocationContext, function_call_event: Event, @@ -151,6 +203,7 @@ async def handle_function_calls_async( # do not use "args" as the variable name, because it is a reserved keyword # in python debugger. function_args = function_call.args or {} + function_response: Optional[dict] = None for callback in agent.canonical_before_tool_callbacks: @@ -162,6 +215,17 @@ async def handle_function_calls_async( if function_response: break + # Check if an approval is required *before* attempting to call the tool. + # This allows the system to suspend a tool call if its policies are not met by existing grants. 
+ if not function_response: + function_response = ApprovalHandler.get_approval_request( + function_call, + invocation_context.session.state, + tool_context, + user_id=invocation_context.user_id, + session_id=invocation_context.session.id, + ) + if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context @@ -206,7 +270,7 @@ async def handle_function_calls_async( event_id=merged_event.id, function_response_event=merged_event, ) - return merged_event + return merged_event # TODO work out when / how to include a suitable state delta for clearing the suspended tool calls async def handle_function_calls_live( @@ -489,12 +553,55 @@ def merge_parallel_function_response_events( merged_actions = EventActions() merged_requested_auth_configs = {} + merged_requested_approvals = [] + # Merge approval-related state delta fields + merged_state_delta_approvals_grants = None + merged_state_delta_approvals_suspended_function_calls = None for event in function_response_events: merged_requested_auth_configs.update(event.actions.requested_auth_configs) + merged_requested_approvals.extend(event.actions.requested_approvals) + # Consolidate 'approvals__grants' from all event state_deltas + if 'approvals__grants' in event.actions.state_delta: + if merged_state_delta_approvals_grants is None: + merged_state_delta_approvals_grants = [] + merged_state_delta_approvals_grants.extend( + event.actions.state_delta['approvals__grants'] + ) + # Consolidate and correctly order 'approvals__suspended_function_calls' + if 'approvals__suspended_function_calls' in event.actions.state_delta: + if merged_state_delta_approvals_suspended_function_calls is None: + merged_state_delta_approvals_suspended_function_calls = [] + merged_state_delta_approvals_suspended_function_calls.extend( + event.actions.state_delta['approvals__suspended_function_calls'] + ) merged_actions = merged_actions.model_copy( update=event.actions.model_dump() ) merged_actions.requested_auth_configs = merged_requested_auth_configs + merged_actions.requested_approvals = merged_requested_approvals + + if merged_state_delta_approvals_grants is not None: + # Ensure uniqueness if grants were somehow duplicated across events (though unlikely for new grants) + # For simplicity here, we assume they are additive and already unique if coming from ApprovalHandler. + merged_actions.state_delta['approvals__grants'] = ( + merged_state_delta_approvals_grants + ) + if merged_state_delta_approvals_suspended_function_calls is not None: + # Ensure suspended function calls are unique by ID and ordered by sequence + # The sort ensures that if multiple updates for the same call ID exist (e.g., suspended then resumed + # within the batch, though unlikely), the latest sequence one is kept by the dict comprehension. + merged_state_delta_approvals_suspended_function_calls = list( + { + fcs['function_call']['id']: fcs + for fcs in sorted( + merged_state_delta_approvals_suspended_function_calls, + key=lambda fcs: fcs['sequence'], + ) + }.values() + ) + merged_actions.state_delta['approvals__suspended_function_calls'] = ( + merged_state_delta_approvals_suspended_function_calls + ) # Create the new merged event merged_event = Event( invocation_id=Event.new_id(), diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 787a767972..919c90f00f 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,6 +22,7 @@ from . 
import contents from . import identity from . import instructions +from ...approval import approval_request_processor from ...auth import auth_preprocessor from .base_llm_flow import BaseLlmFlow @@ -29,16 +30,29 @@ class SingleFlow(BaseLlmFlow): - """SingleFlow is the LLM flows that handles tools calls. + """SingleFlow handles tool calls and basic LLM interactions for an agent. - A single flow only consider an agent itself and tools. - No sub-agents are allowed for single flow. + It processes requests by running them through a series of preprocessors: + - Basic content and history management. + - Approval request processing: Manages the approval lifecycle for tool calls, + handling incoming approval grants and resuming suspended calls before they + are sent to the LLM or for execution if already approved. + - Authentication preprocessing: Handles auth challenges if tools require them. + - Instruction processing: Incorporates system instructions. + - Identity processing: Adds agent identity information. + - Content finalization: Prepares the final content for the LLM. + - Natural Language Planning preprocessing. + - Code execution related preprocessing. + + After the LLM call, response processors for NL Planning and code execution are run. + This flow does not involve sub-agents. """ def __init__(self): super().__init__() self.request_processors += [ basic.request_processor, + approval_request_processor.request_processor, auth_preprocessor.request_processor, instructions.request_processor, identity.request_processor, diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 2b23dcf573..8c13abd6bc 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -21,6 +21,8 @@ from pydantic import model_validator from typing_extensions import override +from ..approval.approval_grant import ApprovalAction +from ..approval.approval_policy import FunctionToolPolicy, TOOL_NAMESPACE, register_policy_for_tool from . import _automatic_function_calling_util from ..memory.in_memory_memory_service import InMemoryMemoryService from ..runners import Runner @@ -176,3 +178,54 @@ async def run_async( [p.text for p in last_event.content.parts if p.text] ) return tool_result + + +DEFAULT_APPROVED_AGENT_TOOL_POLICIES = [ + FunctionToolPolicy( + actions=[ApprovalAction(f'{TOOL_NAMESPACE}:agent:use')], + resource_mappers=lambda args: ['*'], + ) +] +"""Default approval policies for an `AgentTool` if it were to be approved by default. +This is not used by `AgentTool` itself but provided as a possible default for `ApprovedAgentTool`. +It grants a blanket permission for an agent to use any other agent tool. +More specific policies should be used in practice. +""" + + +class ApprovedAgentTool(AgentTool): + """An `AgentTool` that integrates with the approval system by registering policies. + + This tool allows an agent to call another agent. It automatically registers + approval policies upon initialization, typically requiring an 'agent:use' action + on a resource representing the target agent. + """ + + def __init__( + self, + agent: BaseAgent, + skip_summarization: bool = False, + policies: list[FunctionToolPolicy] = None, + ): + """Initializes the ApprovedAgentTool and registers its approval policies. + + Args: + agent: The `BaseAgent` instance that this tool will invoke. + skip_summarization: If True, the full response from the sub-agent is returned. + Otherwise, the response may be summarized. 
+ policies: An optional list of `FunctionToolPolicy` objects to register for this + agent tool. If None, a default policy is created that requires the + action 'tool:agent:use' on the resource 'tool:agents:'. + """ + super().__init__(agent=agent, skip_summarization=skip_summarization) + if policies is None: + policies = [ + FunctionToolPolicy( + actions=[ApprovalAction(f'{TOOL_NAMESPACE}:agent:use')], + resource_mappers=lambda args: [ + f'{TOOL_NAMESPACE}:agents:{self.name}' + ], + ) + ] + for policy in policies: + register_policy_for_tool(self, policy) diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index db4c533d21..3d65c452bc 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -19,6 +19,7 @@ from . import _automatic_function_calling_util from .function_tool import FunctionTool +from ..approval.approval_policy import FunctionToolPolicy, register_policy_for_tool try: from crewai.tools import BaseTool as CrewaiBaseTool @@ -45,7 +46,25 @@ class CrewaiTool(FunctionTool): tool: CrewaiBaseTool """The wrapped CrewAI tool.""" - def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str): + def __init__( + self, + tool: CrewaiBaseTool, + *, + name: str, + description: str, + policies: list[FunctionToolPolicy] = None, + ): + """Initializes a CrewaiTool, wrapping a CrewAI BaseTool, and registers policies. + + Args: + tool: The CrewAI `BaseTool` instance to wrap. + name: The name for this tool. If empty, it attempts to use `tool.name` + (replacing spaces with underscores and lowercasing). + description: The description for this tool. If empty, it attempts to use + `tool.description`. + policies: An optional list of `FunctionToolPolicy` objects to register for this tool. + Each policy defines actions and resource mappings for approval. + """ super().__init__(tool.run) self.tool = tool if name: @@ -58,6 +77,8 @@ def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str): self.description = description elif tool.description: self.description = tool.description + for policy in policies or []: + register_policy_for_tool(self, policy) @override def _get_declaration(self) -> types.FunctionDeclaration: diff --git a/src/google/adk/tools/langchain_tool.py b/src/google/adk/tools/langchain_tool.py index b36c3f5e6c..b9cc5687d2 100644 --- a/src/google/adk/tools/langchain_tool.py +++ b/src/google/adk/tools/langchain_tool.py @@ -23,6 +23,7 @@ from . import _automatic_function_calling_util from .function_tool import FunctionTool +from ..approval.approval_policy import FunctionToolPolicy, register_policy_for_tool class LangchainTool(FunctionTool): @@ -57,7 +58,24 @@ def __init__( tool: Union[BaseTool, object], name: Optional[str] = None, description: Optional[str] = None, + policies: list[FunctionToolPolicy] = None, ): + """Initializes a LangchainTool, wrapping a Langchain tool, and registers policies. + + The wrapped tool can be a Langchain `BaseTool` subclass or any object with a + `run` (or `_run`) method and a `pydantic_schema` (or `args_schema`, or `tool_args`). + + Args: + tool: The Langchain tool or compatible object to wrap. + name: Optional name for the tool. If None, it's inferred from `tool.name`. + description: Optional description. If None, inferred from `tool.description`. + policies: An optional list of `FunctionToolPolicy` objects to register for this tool. + Each policy defines actions and resource mappings for approval. 
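+        For example, a sketch (the namespace and the `path` argument here are
+        hypothetical, not part of the API):
+
+            FunctionToolPolicy(
+                actions=[ApprovalAction('tool:files:read')],
+                resource_mappers=lambda args: [f"tool:files:{args['path']}"],
+            )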
+
+    Raises:
+      ValueError: If the provided tool object does not have a callable 'run'
+        or '_run' method, or if its argument schema cannot be determined.
+    """
     # Check if the tool has a 'run' method
     if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
       raise ValueError("Langchain tool must have a 'run' or '_run' method")
@@ -82,6 +100,9 @@ def __init__(
         self.description = tool.description
       # else: keep default from FunctionTool
 
+    for policy in policies or []:
+      register_policy_for_tool(self, policy)
+
   @override
   def _get_declaration(self) -> types.FunctionDeclaration:
     """Build the function declaration for the tool.
diff --git a/src/google/adk/tools/tool_context.py b/src/google/adk/tools/tool_context.py
index e99d42caaa..5c9ae1cff5 100644
--- a/src/google/adk/tools/tool_context.py
+++ b/src/google/adk/tools/tool_context.py
@@ -14,13 +14,14 @@
 
 from __future__ import annotations
 
 from typing import Optional
 from typing import TYPE_CHECKING
 
 from ..agents.callback_context import CallbackContext
+from ..approval.approval_request import ApprovalRequest
 from ..auth.auth_credential import AuthCredential
 from ..auth.auth_handler import AuthHandler
 from ..auth.auth_tool import AuthConfig
 
 if TYPE_CHECKING:
   from ..agents.invocation_context import InvocationContext
@@ -69,6 +73,38 @@ def request_credential(self, auth_config: AuthConfig) -> None:
   def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
     return AuthHandler(auth_config).get_auth_response(self.state)
 
+  def request_approval(
+      self,
+      approval_request: ApprovalRequest,
+  ) -> None:
+    """Allows a tool to explicitly request an approval during its execution.
+
+    When a tool, during its `run_async` or `_call_live` method, determines
+    that it needs a permission not covered by the upfront policy checks (or
+    one that depends on a dynamic condition), it can call this method.
+
+    The provided `ApprovalRequest` is appended to the
+    `_event_actions.requested_approvals` list. The `ApprovalHandler` or the
+    flow (e.g., `generate_approval_event` in `llm_flows.functions`) then
+    processes these requests, suspending execution and prompting the user.
+
+    Note: This is a lower-level mechanism. Approvals are normally checked
+    *before* a tool is run, based on registered `ApprovalPolicy` objects;
+    this method is for cases where a tool itself needs to initiate an
+    approval mid-flight.
+
+    Args:
+      approval_request: An `ApprovalRequest` object detailing the function
+        call, challenges, and grantee for the requested approval.
+
+    Raises:
+      ValueError: If `self.function_call_id` is not set, which is necessary
+        to associate the approval request with the correct tool invocation.
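+
+    Example (a sketch only; the tool and the elided field values are
+    hypothetical):
+
+      def wipe_dataset(dataset: str, tool_context: ToolContext):
+        tool_context.request_approval(
+            ApprovalRequest(
+                function_call=...,  # the in-flight FunctionCall
+                challenges=...,     # actions/resources still needing a grant
+                grantee=...,        # actor the grant would be issued to
+            )
+        )
+        return {'status': 'approval_requested'}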
+ """ + if not self.function_call_id: + raise ValueError('function_call_id is not set.') + + self._event_actions.requested_approvals.append(approval_request) + async def list_artifacts(self) -> list[str]: """Lists the filenames of the artifacts attached to the current session.""" if self._invocation_context.artifact_service is None: diff --git a/tests/unittests/approval/__init__.py b/tests/unittests/approval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unittests/approval/test_approval_grant.py b/tests/unittests/approval/test_approval_grant.py new file mode 100644 index 0000000000..0dc6f30471 --- /dev/null +++ b/tests/unittests/approval/test_approval_grant.py @@ -0,0 +1,64 @@ +import uuid + +import pytest + +from google.adk.approval.approval_grant import ApprovalAction, ApprovalActor, ApprovalGrant +from google.adk.approval.approval_policy import ApprovalPolicyRegistry, resource_parameters, tool_policy +from google.adk.tools import FunctionTool, ToolContext + + +@pytest.fixture() +def read_file_tool(): + LOCAL_FILE_NAMESPACE = "tool:local_file" + LOCAL_FILE_READ_PERMISSION = ApprovalAction(f"{LOCAL_FILE_NAMESPACE}:read") + + @tool_policy( + actions=[LOCAL_FILE_READ_PERMISSION], + resources=resource_parameters(LOCAL_FILE_NAMESPACE, ["path"]), + ) + def read_file(*, path: str): + try: + with open(path, "r") as f: + return f.read() + except Exception as e: + return str(e) + + return read_file + + +def test_approval_grant__once(read_file_tool, tmp_path): + tool = FunctionTool(read_file_tool) + + user_id = str(uuid.uuid4()) + session_id = str(uuid.uuid4()) + function_call_id = str(uuid.uuid4()) + path = str(tmp_path / "file-to-read.txt") + args = dict(path=path) + + policies = ApprovalPolicyRegistry.get_tool_policies(tool.name) + assert len(policies) == 1 + + policy = policies[0] + grant = ApprovalGrant( + effect="allow", + actions=policy.actions, + resources=policy.get_resources(args), + grantee=ApprovalActor( + id=f"{tool.name}:{function_call_id}", + type="tool", + on_behalf_of=ApprovalActor( + id=session_id, + type="agent", + on_behalf_of=ApprovalActor( + id=user_id, + type="user", + ), + ), + ), + grantor=ApprovalActor( + id=user_id, + type="user", + ), + ) + + assert grant diff --git a/tests/unittests/approval/test_approval_handler.py b/tests/unittests/approval/test_approval_handler.py new file mode 100644 index 0000000000..d7f533870a --- /dev/null +++ b/tests/unittests/approval/test_approval_handler.py @@ -0,0 +1,803 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
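+
+"""Unit tests for ApprovalHandler.
+
+A quick sketch of the wildcard semantics the tests below assert: `*` matches
+within a single `:`-delimited segment of a resource, and a bare `*` matches
+any resource:
+
+    _resource_met("test:abc:resource", "test:*:resource")      -> True
+    _resource_met("test:abc:def:resource", "test:*:resource")  -> False
+    _resource_met("any:thing:at:all", "*")                     -> True
+"""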
+import pytest +from google.genai.types import FunctionCall + +from google.adk.approval.approval_grant import ApprovalAction, ApprovalActor, ApprovalEffect, ApprovalGrant +from google.adk.approval.approval_handler import ApprovalHandler +from google.adk.approval.approval_policy import FunctionToolPolicy, register_tool_policy +from google.adk.approval.approval_request import ApprovalDenied, ApprovalResponse +from google.adk.sessions import State + + +class TestResourceMet: + """Tests for the _resource_met method.""" + + def test_exact_match(self): + """Test when resources match exactly.""" + assert ( + ApprovalHandler._resource_met("test:resource", "test:resource") is True + ) + + def test_no_match(self): + """Test when resources don't match at all.""" + assert ( + ApprovalHandler._resource_met("test:resource1", "test:resource2") + is False + ) + + def test_wildcard_prefix(self): + """Test wildcard at the beginning of grant resource.""" + assert ApprovalHandler._resource_met("test:resource", "*:resource") is True + assert ApprovalHandler._resource_met("other:resource", "*:resource") is True + assert ApprovalHandler._resource_met("test:other", "*:resource") is False + + def test_wildcard_suffix(self): + """Test wildcard at the end of grant resource.""" + assert ApprovalHandler._resource_met("test:resource", "test:*") is True + assert ApprovalHandler._resource_met("test:other", "test:*") is True + assert ApprovalHandler._resource_met("other:resource", "test:*") is False + + def test_wildcard_middle(self): + """Test wildcard in the middle of grant resource.""" + assert ( + ApprovalHandler._resource_met("test:abc:resource", "test:*:resource") + is True + ) + assert ( + ApprovalHandler._resource_met("test:xyz:resource", "test:*:resource") + is True + ) + assert ( + ApprovalHandler._resource_met("test:abc:other", "test:*:resource") + is False + ) + + def test_full_wildcard(self): + """Test full wildcard grant resource.""" + assert ApprovalHandler._resource_met("test:resource", "*") is True + assert ApprovalHandler._resource_met("any:thing:at:all", "*") is True + + def test_multiple_wildcards(self): + """Test grant resource with multiple wildcards.""" + assert ( + ApprovalHandler._resource_met("test:abc:resource", "test:*:*") is True + ) + assert ApprovalHandler._resource_met("test:abc:xyz", "test:*:*") is True + assert ( + ApprovalHandler._resource_met("other:abc:resource", "test:*:*") is False + ) + + def test_wildcard_segments(self): + """Test wildcard matching entire segments.""" + # Wildcard matching a single segment + assert ( + ApprovalHandler._resource_met("test:abc:resource", "test:*:resource") + is True + ) + + # Wildcard matching multiple segments + assert ( + ApprovalHandler._resource_met( + "test:abc:def:resource", "test:*:resource" + ) + is False + ) + assert ( + ApprovalHandler._resource_met("test:resource", "test:*:resource") + is False + ) + + def test_complex_patterns(self): + """Test more complex wildcard patterns.""" + assert ( + ApprovalHandler._resource_met( + "tool_call:read_file:/path/to/file", "tool_call:read_file:*" + ) + is True + ) + assert ( + ApprovalHandler._resource_met( + "tool_call:read_file:/path/to/file", "tool_call:*:/path/to/*" + ) + is True + ) + assert ( + ApprovalHandler._resource_met( + "tool_call:write_file:/path/to/file", "tool_call:read_*:*" + ) + is False + ) + + +@pytest.fixture +def sample_action(): + """Create a sample approval action for testing.""" + return ApprovalAction("test:read") + + +@pytest.fixture +def different_action(): + """Create a 
different approval action for testing.""" + return ApprovalAction("test:write") + + +@pytest.fixture +def sample_actor(): + """Create a sample approval actor for testing.""" + return ApprovalActor(id="test_user", type="user") + + +@pytest.fixture +def allow_grant(sample_action, sample_actor): + """Create an allow grant for testing.""" + return ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123", "test:resource:*"], + grantee=sample_actor, + grantor=sample_actor, + ) + + +@pytest.fixture +def deny_grant(sample_action, sample_actor): + """Create a deny grant for testing.""" + return ApprovalGrant( + effect=ApprovalEffect.deny, + actions=[sample_action], + resources=["test:resource:456", "test:special:*"], + grantee=sample_actor, + grantor=sample_actor, + ) + + +class TestActionGrantedForResource: + """Tests for the _action_granted_for_resource method.""" + + def test_action_not_in_grant( + self, allow_grant, different_action, sample_actor + ): + """Test when the requested action is not in the grant's actions.""" + result = ApprovalHandler._action_granted_for_resource( + allow_grant, + different_action, + "test:resource:123", + sample_actor, + ) + assert result is None + + def test_resource_not_in_grant( + self, allow_grant, sample_action, sample_actor + ): + """Test when the requested resource is not in the grant's resources.""" + result = ApprovalHandler._action_granted_for_resource( + allow_grant, + sample_action, + "test:different:789", + sample_actor, + ) + assert result is None + + def test_exact_resource_match_allow( + self, allow_grant, sample_action, sample_actor + ): + """Test when there's an exact match with an allow grant.""" + result = ApprovalHandler._action_granted_for_resource( + allow_grant, sample_action, "test:resource:123", sample_actor + ) + assert result == ApprovalEffect.allow + + def test_exact_resource_match_deny( + self, deny_grant, sample_action, sample_actor + ): + """Test when there's an exact match with a deny grant.""" + result = ApprovalHandler._action_granted_for_resource( + deny_grant, sample_action, "test:resource:456", sample_actor + ) + assert result == ApprovalEffect.deny + + def test_wildcard_resource_match_allow( + self, allow_grant, sample_action, sample_actor + ): + """Test when there's a wildcard match with an allow grant.""" + result = ApprovalHandler._action_granted_for_resource( + allow_grant, sample_action, "test:resource:anything", sample_actor + ) + assert result == ApprovalEffect.allow + + def test_wildcard_resource_match_deny( + self, deny_grant, sample_action, sample_actor + ): + """Test when there's a wildcard match with a deny grant.""" + result = ApprovalHandler._action_granted_for_resource( + deny_grant, sample_action, "test:special:anything", sample_actor + ) + assert result == ApprovalEffect.deny + + def test_first_resource_match_is_returned(self, sample_action, sample_actor): + """Test that the first matching resource's effect is returned.""" + # Create a grant with multiple matching resources + grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:first", "test:resource:second"], + grantee=sample_actor, + grantor=sample_actor, + ) + + result = ApprovalHandler._action_granted_for_resource( + grant, sample_action, "test:resource:first", sample_actor + ) + assert result == ApprovalEffect.allow + + # The second resource should also match if checked first + result = ApprovalHandler._action_granted_for_resource( + grant, 
sample_action, "test:resource:second", sample_actor + ) + assert result == ApprovalEffect.allow + + def test_grantee_match(self, sample_action, sample_actor): + """Test when the grantee matches the grant's grantee.""" + # Create a grant for a specific grantee + grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + + # Use the same grantee in the check + result = ApprovalHandler._action_granted_for_resource( + grant, sample_action, "test:resource:123", actor=sample_actor + ) + assert result == ApprovalEffect.allow + + def test_grantee_mismatch(self, sample_action, sample_actor): + """Test when the grantee doesn't match the grant's grantee.""" + # Create a grant for a specific grantee + grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + + # Use a different grantee in the check + different_actor = ApprovalActor(id="different_user", type="user") + result = ApprovalHandler._action_granted_for_resource( + grant, sample_action, "test:resource:123", actor=different_actor + ) + assert result is None + + def test_grantee_with_on_behalf_of(self, sample_action, sample_actor): + """Test when the grant's grantee has an on_behalf_of field.""" + # Create an actor acting on behalf of another actor + original_actor = ApprovalActor(id="original_user", type="user") + delegated_actor = ApprovalActor( + id="delegate_user", type="agent", on_behalf_of=original_actor + ) + + # Create a grant for a delegated actor + grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=delegated_actor, + grantor=original_actor, + ) + + # Use the same delegated actor in the check + result = ApprovalHandler._action_granted_for_resource( + grant, sample_action, "test:resource:123", actor=delegated_actor + ) + assert result == ApprovalEffect.allow + + # Use a different delegated actor in the check + different_delegated = ApprovalActor( + id="delegate_user", + type="agent", + on_behalf_of=ApprovalActor(id="different_user", type="user"), + ) + result = ApprovalHandler._action_granted_for_resource( + grant, sample_action, "test:resource:123", actor=different_delegated + ) + assert result is None + + +class TestCheckActionOnResourceAgainstGrants: + # Tests for _check_action_on_resource_against_grants + def test_no_grants(self, sample_action, sample_actor): + """Test when there are no grants available.""" + grants = [] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, + "test:resource:123", + grants, + sample_actor, + ) + assert result is None + + def test_allow_grant_match(self, sample_action, sample_actor): + """Test when an allow grant matches the action and resource.""" + allow_grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + grants = [allow_grant] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, + "test:resource:123", + grants, + sample_actor, + ) + assert result == ApprovalEffect.allow + + def test_deny_grant_match(self, sample_action, sample_actor): + """Test when a deny grant matches the action and resource.""" + deny_grant = ApprovalGrant( + effect=ApprovalEffect.deny, + actions=[sample_action], + resources=["test:resource:123"], + 
grantee=sample_actor, + grantor=sample_actor, + ) + grants = [deny_grant] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, sample_actor + ) + assert result == ApprovalEffect.deny + + def test_deny_takes_precedence(self, sample_action, sample_actor): + """Test that deny grants take precedence over allow grants.""" + allow_grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + deny_grant = ApprovalGrant( + effect=ApprovalEffect.deny, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + # Note that deny_grant is after allow_grant in the list + grants = [allow_grant, deny_grant] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, sample_actor + ) + assert result == ApprovalEffect.deny + + def test_first_match_returned( + self, sample_action, different_action, sample_actor + ): + """Test that the first matching grant's effect is returned.""" + allow_grant1 = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + allow_grant2 = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:456"], + grantee=sample_actor, + grantor=sample_actor, + ) + grants = [allow_grant1, allow_grant2] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, sample_actor + ) + assert result == ApprovalEffect.allow + + def test_no_match(self, sample_action, different_action, sample_actor): + """Test when no grants match the action and resource.""" + allow_grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[different_action], # Different action + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + grants = [allow_grant] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, sample_actor + ) + assert result is None + + def test_wildcard_resource(self, sample_action, sample_actor): + """Test with a wildcard resource in a grant.""" + allow_grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:*"], # Wildcard resource + grantee=sample_actor, + grantor=sample_actor, + ) + grants = [allow_grant] + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, sample_actor + ) + assert result == ApprovalEffect.allow + + def test_with_grantee(self, sample_action, sample_actor): + """Test that the grantee is correctly passed to _action_granted_for_resource.""" + # Create a grant for a specific grantee + grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=sample_actor, + grantor=sample_actor, + ) + grants = [grant] + + # Test with the matching grantee + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, actor=sample_actor + ) + assert result == ApprovalEffect.allow + + # Test with a different grantee + different_actor = ApprovalActor(id="different_user", type="user") + result = ApprovalHandler._check_action_on_resource_against_grants( + sample_action, "test:resource:123", grants, 
actor=different_actor + ) + assert result is None + + +def create_tool_policy(tool_name, actions, static_resources): + return FunctionToolPolicy( + policy_name=FunctionToolPolicy.format_name(tool_name=tool_name), + actions=actions, + resource_mappers=lambda args: static_resources, + ) + + +class TestGetPendingChallenges: + """Tests for the get_pending_challenges method.""" + + @pytest.fixture + def mock_state(self): + """Create a mock state for testing.""" + return State(value={}, delta={}) + + @pytest.fixture + def sample_action(self): + """Create a sample approval action for testing.""" + return ApprovalAction("test:read") + + @pytest.fixture + def sample_actor(self): + """Create a sample approval actor for testing.""" + return ApprovalActor(id="test_user", type="user") + + @pytest.fixture + def function_call_id(self): + """Create a function call ID for testing.""" + return "test_function_id" + + @pytest.fixture + def allow_grant(self, sample_action, sample_actor): + """Create an allow grant for testing.""" + function_call_actor = ApprovalActor( + id="tool:test_tool:test_function_id", + type="tool", + on_behalf_of=ApprovalActor( + id="test_session", type="agent", on_behalf_of=sample_actor + ), + ) + + return ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:123"], + grantee=function_call_actor, + grantor=sample_actor, + ) + + @pytest.fixture + def deny_grant(self, sample_action, sample_actor): + """Create a deny grant for testing.""" + function_call_actor = ApprovalActor( + id="tool:test_tool:test_function_id", + type="tool", + on_behalf_of=ApprovalActor( + id="test_session", type="agent", on_behalf_of=sample_actor + ), + ) + + return ApprovalGrant( + effect=ApprovalEffect.deny, + actions=[sample_action], + resources=["test:resource:456"], + grantee=function_call_actor, + grantor=sample_actor, + ) + + def test_allow_when_no_policies(self, mock_state, function_call_id): + """Test that approval is granted when there are no policies.""" + approval_request = ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + + assert not approval_request + + def test_challenge_when_no_grants( + self, mock_state, sample_action, function_call_id + ): + """Test that a challenge is created when there are no grants.""" + # Create a mock policy with one action/resource pair + policy = create_tool_policy( + "test_tool", [sample_action], ["test:resource:123"] + ) + register_tool_policy(policy) + + approval_request = ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + + assert len(approval_request.challenges) == 1 + assert sample_action in approval_request.challenges[0].actions + assert "test:resource:123" in approval_request.challenges[0].resources + + def test_allow_when_all_actions_granted( + self, sample_action, allow_grant, function_call_id, mock_state + ): + """Test that approval is granted when all actions are allowed.""" + # Add grant to state + mock_state["approvals__grants"] = [allow_grant] + + # Create a mock policy with one action/resource pair that matches the grant + policy = create_tool_policy( + "test_tool", [sample_action], ["test:resource:123"] + ) + register_tool_policy(policy) + approval_request = ApprovalHandler._get_pending_challenges( + 
mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + + assert not approval_request + + def test_deny_when_any_action_denied( + self, + sample_action, + mock_state, + allow_grant, + deny_grant, + function_call_id, + ): + """Test that approval is denied when any action is denied.""" + # Add grants to state + mock_state["approvals__grants"] = [allow_grant, deny_grant] + + # Create a mock policy with two action/resource pairs, one allowed and one denied + policy = create_tool_policy( + "test_tool", [sample_action], ["test:resource:123", "test:resource:456"] + ) + register_tool_policy(policy) + with pytest.raises(ApprovalDenied) as exc_info: + ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + assert exc_info.value.denied_challenges + + def test_challenge_when_some_actions_not_granted( + self, + sample_action, + mock_state, + allow_grant, + function_call_id, + ): + """Test that a challenge is created when some actions are not granted.""" + # Add grant to state + mock_state["approvals__grants"] = [allow_grant] + + # Create a different action that is not granted + different_action = ApprovalAction("test:write") + + # Create a mock policy with two action/resource pairs, one allowed and one not + policy = create_tool_policy( + "test_tool", [sample_action, different_action], ["test:resource:123"] + ) + register_tool_policy(policy) + approval_request = ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + + assert len(approval_request.challenges) == 1 + assert different_action in approval_request.challenges[0].actions + + def test_multiple_policies( + self, + mock_state, + sample_action, + allow_grant, + function_call_id, + ): + """Test handling multiple policies.""" + # Add grant to state + mock_state["approvals__grants"] = [allow_grant] + + # Create a different action that is not granted + different_action = ApprovalAction("test:write") + + # Create two mock policies with different action/resource pairs + policy1 = create_tool_policy( + "test_tool", [sample_action, different_action], ["test:resource:123"] + ) + policy2 = create_tool_policy( + "test_tool", [sample_action], ["test:resource:789"] + ) + register_tool_policy(policy1) + register_tool_policy(policy2) + approval_request = ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + + assert len(approval_request.challenges) == 2 + assert approval_request.challenges[0].actions == [different_action] + assert approval_request.challenges[0].resources == ["test:resource:123"] + assert approval_request.challenges[1].actions == [sample_action] + assert approval_request.challenges[1].resources == ["test:resource:789"] + + def test_with_wildcard_resources( + self, mock_state, sample_action, function_call_id + ): + """Test with grants that use wildcard resources.""" + # Create a wildcard grant + function_call_actor = ApprovalActor( + id="tool:test_tool:test_function_id", + type="tool", + on_behalf_of=ApprovalActor( + id="test_session", + type="agent", + on_behalf_of=ApprovalActor(id="test_user", 
type="user"), + ), + ) + + wildcard_grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:*"], # Wildcard resource + grantee=function_call_actor, + grantor=ApprovalActor(id="test_user", type="user"), + ) + + # Add grant to state + mock_state["approvals__grants"] = [wildcard_grant] + + # Create a mock policy with a specific resource that should match the wildcard + policy = create_tool_policy( + "test_tool", [sample_action], ["test:resource:specific"] + ) + register_tool_policy(policy) + approval_request = ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id=function_call_id, + ), + user_id="test_user", + session_id="test_session", + ) + + assert not approval_request + + @pytest.mark.parametrize( + "function_call_actor_id", + [ + "tool:test_tool:test_function_id", + "tool:test_tool:*", + "tool:*:*", + "*:*:*", + "*", + ], + ) + def test_with_wildcard_grantee( + self, mock_state, sample_action, function_call_actor_id + ): + """Test with grants that use wildcard resources.""" + # Create a wildcard grant + function_call_actor = ApprovalActor( + id=function_call_actor_id, + type="tool", + on_behalf_of=ApprovalActor( + id="test_session", + type="agent", + on_behalf_of=ApprovalActor(id="test_user", type="user"), + ), + ) + + wildcard_grant = ApprovalGrant( + effect=ApprovalEffect.allow, + actions=[sample_action], + resources=["test:resource:*"], # Wildcard resource + grantee=function_call_actor, + grantor=ApprovalActor(id="test_user", type="user"), + ) + + # Add grant to state + mock_state["approvals__grants"] = [wildcard_grant] + + # Create a mock policy with a specific resource that should match the wildcard + policy = create_tool_policy( + "test_tool", [sample_action], ["test:resource:specific"] + ) + register_tool_policy(policy) + approval_request = ApprovalHandler._get_pending_challenges( + mock_state, + FunctionCall( + name="test_tool", + args={"param": "value"}, + id="test_function_id", + ), + user_id="test_user", + session_id="test_session", + ) + + assert not approval_request diff --git a/tests/unittests/approval/test_approval_policy.py b/tests/unittests/approval/test_approval_policy.py new file mode 100644 index 0000000000..79590b2b2d --- /dev/null +++ b/tests/unittests/approval/test_approval_policy.py @@ -0,0 +1,227 @@ +from crewai.tools import tool as crew_ai_tool +from langchain_core.tools import tool as lg_tool + +from google.adk import Agent +from google.adk.approval.approval_grant import ApprovalAction +from google.adk.approval.approval_policy import ApprovalPolicyRegistry, FunctionToolPolicy, resource_parameters, tool_policy +from google.adk.tools import FunctionTool +from google.adk.tools.agent_tool import AgentTool, ApprovedAgentTool +from google.adk.tools.crewai_tool import CrewaiTool +from google.adk.tools.langchain_tool import LangchainTool +from tests.unittests import utils +from tests.unittests.testing_utils import MockModel + + +def test_tool_policy__simple_policy(tmp_path): + LOCAL_FILE_NAMESPACE = "tool:local_file" + LOCAL_FILE_READ_PERMISSION = ApprovalAction(f"{LOCAL_FILE_NAMESPACE}:read") + + @tool_policy( + actions=[LOCAL_FILE_READ_PERMISSION], + resources=resource_parameters(LOCAL_FILE_NAMESPACE, ["path"]), + ) + def read_file(*, path: str): + try: + with open(path, "r") as f: + return f.read() + except Exception as e: + return str(e) + + function_tool = FunctionTool(read_file) + assert function_tool.name == "read_file" + policies = 
ApprovalPolicyRegistry.get_tool_policies(function_tool.name) + assert len(policies) == 1 + policy: FunctionToolPolicy = policies[0] + assert policy.tool_name == "read_file" + + assert len(policy.actions) == 1 + action: ApprovalAction = policy.actions[0] + assert action == "tool:local_file:read" + + path = str(tmp_path / "file-to-read.txt") + args = dict(path=path) + assert policy.get_resources(args) == [f"tool:local_file:{path}"] + + +def test_tool_policy__policy_with_two_actions(tmp_path): + LOCAL_FILE_NAMESPACE = "tool:local_file" + LOCAL_FILE_READ_ACTION = ApprovalAction(f"{LOCAL_FILE_NAMESPACE}:read") + LOCAL_FILE_WRITE_ACTION = ApprovalAction(f"{LOCAL_FILE_NAMESPACE}:write") + + @tool_policy( + actions=[LOCAL_FILE_READ_ACTION, LOCAL_FILE_WRITE_ACTION], + resources=resource_parameters(LOCAL_FILE_NAMESPACE, ["path"]), + ) + def str_replace(path: str, old_str: str, new_str: str): + try: + with open(path, "r") as f: + content = f.read() + new_content = content.replace(old_str, new_str) + if new_content != content: + with open(path, "w") as f: + return f.write(new_content) + else: + return "Error: could not find string to replace" + except Exception as e: + return str(e) + + function_tool = FunctionTool(str_replace) + assert function_tool.name == "str_replace" + policies = ApprovalPolicyRegistry.get_tool_policies(function_tool.name) + assert len(policies) == 1 + policy = policies[0] + assert policy.tool_name == "str_replace" + + assert len(policy.actions) == 2 + read_action, write_action = policy.actions + assert read_action == "tool:local_file:read" + assert write_action == "tool:local_file:write" + + path = str(tmp_path / "file-to-read.txt") + args = dict(path=path, old_str="something", new_str="something else") + assert policy.get_resources(args) == [f"tool:local_file:{path}"] + + +def test_tool_policy__tool_with_two_policies(tmp_path): + GIT_NAMESPACE = "tool:git" + GIT_READ = ApprovalAction(f"{GIT_NAMESPACE}:read") + GITLAB_NAMESPACE = "tool:gitlab" + GITLAB_READ = ApprovalAction(f"{GITLAB_NAMESPACE}:read") + + @tool_policy( + actions=[GIT_READ], + resources=lambda args: [ + f"{GIT_NAMESPACE}:{args['repo_path']}", + ], + ) + @tool_policy( + actions=[GITLAB_READ], + resources=lambda args: [ + f"{GITLAB_NAMESPACE}:*", + ], + ) + def gitlab_get_mr_for_repo(repo_path: str): + ... 
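+  # Note: stacked @tool_policy decorators apply bottom-up, so the GITLAB
+  # policy (the inner decorator) registers first; the unpacking order
+  # asserted below relies on that registration order.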
+ + function_tool = FunctionTool(gitlab_get_mr_for_repo) + assert function_tool.name == "gitlab_get_mr_for_repo" + policies = ApprovalPolicyRegistry.get_tool_policies(function_tool.name) + assert len(policies) == 2 + gitlab_policy, git_policy = policies + assert git_policy.tool_name == "gitlab_get_mr_for_repo" + assert gitlab_policy.tool_name == "gitlab_get_mr_for_repo" + + assert len(git_policy.actions) == 1 + assert git_policy.actions[0] == "tool:git:read" + assert gitlab_policy.actions[0] == "tool:gitlab:read" + + repo_path = str(tmp_path / "some-repo") + args = dict(repo_path=repo_path) + assert git_policy.get_resources(args) == [f"tool:git:{repo_path}"] + assert gitlab_policy.get_resources(args) == [f"tool:gitlab:*"] + + +def test_tool_policy__agent_tool(): + agent = Agent(name="useful_agent", model=MockModel.create(responses=[])) + agent_tool = ApprovedAgentTool(agent) + assert agent_tool.name == "useful_agent" + policies = ApprovalPolicyRegistry.get_tool_policies(agent_tool.name) + assert len(policies) == 1 + policy = policies[0] + assert policy.tool_name == "useful_agent" + + assert len(policy.actions) == 1 + assert policy.actions[0] == "tool:agent:use" + assert policy.get_resources({}) == [f"tool:agents:useful_agent"] + + +def test_tool_policy__langchain_tool(tmp_path): + GIT_NAMESPACE = "tool:git" + GIT_READ = ApprovalAction(f"{GIT_NAMESPACE}:read") + GITLAB_NAMESPACE = "tool:gitlab" + GITLAB_READ = ApprovalAction(f"{GITLAB_NAMESPACE}:read") + + @lg_tool + def gitlab_get_mr_for_repo(repo_path: str): + """Tool to get gitlab Merge Request for a Repo""" + ... + + langchain_tool = LangchainTool( + gitlab_get_mr_for_repo, + policies=[ + FunctionToolPolicy( + actions=[GIT_READ], + resource_mappers=lambda args: [ + f"{GIT_NAMESPACE}:{args['repo_path']}", + ], + ), + FunctionToolPolicy( + actions=[GITLAB_READ], + resource_mappers=lambda args: [ + f"{GITLAB_NAMESPACE}:*", + ], + ), + ], + ) + assert langchain_tool.name == "gitlab_get_mr_for_repo" + policies = ApprovalPolicyRegistry.get_tool_policies(langchain_tool.name) + assert len(policies) == 2 + git_policy, gitlab_policy = policies + assert git_policy.tool_name == "gitlab_get_mr_for_repo" + assert gitlab_policy.tool_name == "gitlab_get_mr_for_repo" + + assert len(git_policy.actions) == 1 + assert git_policy.actions[0] == "tool:git:read" + assert gitlab_policy.actions[0] == "tool:gitlab:read" + + repo_path = str(tmp_path / "some-repo") + args = dict(repo_path=repo_path) + assert git_policy.get_resources(args) == [f"tool:git:{repo_path}"] + assert gitlab_policy.get_resources(args) == [f"tool:gitlab:*"] + + +def test_tool_policy__crewai_tool(tmp_path): + GIT_NAMESPACE = "tool:git" + GIT_READ = ApprovalAction(f"{GIT_NAMESPACE}:read") + GITLAB_NAMESPACE = "tool:gitlab" + GITLAB_READ = ApprovalAction(f"{GITLAB_NAMESPACE}:read") + + @crew_ai_tool + def gitlab_get_mr_for_repo(repo_path: str): + """Tool to get gitlab Merge Request for a Repo""" + ... 
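+  # The empty name/description below exercise CrewaiTool's documented
+  # fallback to the wrapped tool's own name (lowercased, spaces replaced
+  # with underscores) and description.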
+ + crewai_tool = CrewaiTool( + gitlab_get_mr_for_repo, + name="", + description="", + policies=[ + FunctionToolPolicy( + actions=[GIT_READ], + resource_mappers=lambda args: [ + f"{GIT_NAMESPACE}:{args['repo_path']}", + ], + ), + FunctionToolPolicy( + actions=[GITLAB_READ], + resource_mappers=lambda args: [ + f"{GITLAB_NAMESPACE}:*", + ], + ), + ], + ) + assert crewai_tool.name == "gitlab_get_mr_for_repo" + policies = ApprovalPolicyRegistry.get_tool_policies(crewai_tool.name) + assert len(policies) == 2 + git_policy, gitlab_policy = policies + assert git_policy.tool_name == "gitlab_get_mr_for_repo" + assert gitlab_policy.tool_name == "gitlab_get_mr_for_repo" + + assert len(git_policy.actions) == 1 + assert git_policy.actions[0] == "tool:git:read" + assert gitlab_policy.actions[0] == "tool:gitlab:read" + + repo_path = str(tmp_path / "some-repo") + args = dict(repo_path=repo_path) + assert git_policy.get_resources(args) == [f"tool:git:{repo_path}"] + assert gitlab_policy.get_resources(args) == [f"tool:gitlab:*"] diff --git a/tests/unittests/approval/test_approval_preprocessor.py b/tests/unittests/approval/test_approval_preprocessor.py new file mode 100644 index 0000000000..3a17cc67a0 --- /dev/null +++ b/tests/unittests/approval/test_approval_preprocessor.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.genai import types + +from google.adk import Agent +from google.adk.approval.approval_grant import ApprovalAction, ApprovalActor, ApprovalGrant +from google.adk.approval.approval_policy import TOOL_NAMESPACE, tool_policy +from google.adk.approval.approval_request import ApprovalResponse, FunctionCallStatus +from google.adk.approval.approval_request_processor import REQUEST_APPROVAL_FUNCTION_CALL_NAME +from google.adk.approval.approval_request_processor import request_processor +from google.adk.models.llm_request import LlmRequest + +from .. 
import utils +from ..flows.llm_flows.test_functions_sequential import function_call +from ..testing_utils import InMemoryRunner, MockModel + + +@pytest.mark.asyncio +async def test_request_processor__no_approvals(): + @tool_policy( + actions=[ApprovalAction("mutate_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + def increase_by_one(x: int) -> int: + return x + 1 + + agent = Agent( + name="root_agent", + model=MockModel.create(responses=[]), + tools=[increase_by_one], + ) + runner = InMemoryRunner(agent) + session = runner.session + + function_call_to_approve = types.FunctionCall( + id="adk-tool-id-123", + name="increase_by_one", + args={"x": 1}, + ) + grant = ApprovalGrant( + effect="allow", + actions=[ApprovalAction("mutate_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{function_call_to_approve.id}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session.id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + new_message = types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant], + ).model_dump(), + ) + ], + ) + invocation_context = runner.runner._new_invocation_context( + session, new_message=new_message + ) + await runner.runner._append_new_message_to_session( + session, + new_message, + invocation_context, + ) + assert invocation_context.session.state.get("approvals__grants", []) == [] + invocation_context.session.state["approvals__grants"] = [] + invocation_context.session.state["approvals__suspended_function_calls"] = [ + FunctionCallStatus( + function_call=function_call_to_approve, + status="suspended", + sequence=0, + ) + ] + events = [ + event + async for event in request_processor.run_async( + invocation_context, llm_request=LlmRequest() + ) + ] + state = {**invocation_context.session.state} + for event in events: + if event.actions.state_delta: + state = {**state, **event.actions.state_delta} + + assert state["approvals__grants"] == [grant.model_dump(mode="json")] + + assert len(events) == 2 + grants_state_delta_event, response_event = events + + assert not grants_state_delta_event.content + assert grants_state_delta_event.actions.state_delta == { + "approvals__grants": [grant.model_dump(mode="json")], + } + + assert len(response_event.content.parts) == 1 + assert ( + response_event.content.parts[0].function_response.id + == function_call_to_approve.id + ) + assert ( + response_event.content.parts[0].function_response.name + == "increase_by_one" + ) + assert response_event.content.parts[0].function_response.response == { + "result": 2 + } + assert state["approvals__suspended_function_calls"] == [{ + "function_call": { + "args": { + "x": 1, + }, + "id": "adk-tool-id-123", + "name": "increase_by_one", + }, + "sequence": 1, + "status": "resumed", + }] diff --git a/tests/unittests/approval/utils.py b/tests/unittests/approval/utils.py new file mode 100644 index 0000000000..41aaaed5ab --- /dev/null +++ b/tests/unittests/approval/utils.py @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for approval tests."""
+
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+
+
+def create_test_session(events=None, state_data=None):
+  """Creates a test session with the given events and state.
+
+  Args:
+    events: A list of events to add to the session.
+    state_data: A dictionary of state data to seed the session with.
+
+  Returns:
+    A Session object with the given events and state.
+  """
+  session_service = InMemorySessionService()
+
+  # Create the session, seeding it with a copy of the provided state data.
+  session = session_service.create_session(
+      app_name="test_app",
+      user_id="test_user",
+      state=dict(state_data or {}),
+  )
+
+  # Append events; any state_delta an event carries is applied by
+  # append_event.
+  if events:
+    for event in events:
+      session_service.append_event(session=session, event=event)
+
+  return session
diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py
index ad204005eb..baa686195c 100644
--- a/tests/unittests/conftest.py
+++ b/tests/unittests/conftest.py
@@ -14,11 +14,14 @@
 
 import os
 
+import pytest
 from pytest import fixture
 from pytest import FixtureRequest
 from pytest import hookimpl
 from pytest import Metafunc
 
+from google.adk.approval.approval_policy import ApprovalPolicyRegistry
+
 _ENV_VARS = {
     'GOOGLE_API_KEY': 'fake_google_api_key',
     'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project',
@@ -71,3 +74,20 @@ def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
     if mark.name == 'parametrize' and mark.args[0] == mark_name:
       return True
   return False
+
+
+@pytest.fixture(autouse=True)
+def isolate_policy_registry():
+  """Snapshots and clears the global policy registry around every test."""
+  backup_policies = ApprovalPolicyRegistry.tool_policies[:]
+  ApprovalPolicyRegistry.tool_policies.clear()
+  yield backup_policies
+  ApprovalPolicyRegistry.tool_policies.clear()
+  ApprovalPolicyRegistry.tool_policies.extend(backup_policies)
+
+
+@pytest.fixture()
+def use_system_policies(isolate_policy_registry):
+  """Restores the snapshotted system policies for tests that need them."""
+  ApprovalPolicyRegistry.tool_policies.extend(isolate_policy_registry)
+  yield
diff --git a/tests/unittests/flows/llm_flows/test_approval.py b/tests/unittests/flows/llm_flows/test_approval.py
new file mode 100644
index 0000000000..7611ca9e80
--- /dev/null
+++ b/tests/unittests/flows/llm_flows/test_approval.py
@@ -0,0 +1,799 @@
+import pytest
+from google.genai import types
+
+from google.adk import Agent
+from google.adk.approval.approval_grant import ApprovalAction, ApprovalActor, ApprovalGrant
+from google.adk.approval.approval_policy import TOOL_NAMESPACE, resource_parameter_map, resource_parameters, tool_policy
+from google.adk.approval.approval_request import ApprovalResponse
+from google.adk.approval.approval_request_processor import REQUEST_APPROVAL_FUNCTION_CALL_NAME
+from ...
import utils +from ...testing_utils import InMemoryRunner, MockModel + + +def test_call_approval_requested(): + responses = [ + types.Part.from_function_call(name="increase_by_one", args={"x": 1}), + "response1", + "response2", + ] + mock_model = MockModel.create(responses=responses) + + @tool_policy( + actions=[ApprovalAction("mutate_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + def increase_by_one(x: int) -> int: + return x + 1 + + agent = Agent(name="root_agent", model=mock_model, tools=[increase_by_one]) + runner = InMemoryRunner(agent) + preapproval_events = runner.run("test") + assert len(preapproval_events) == 4 + + state = runner.session.state + for event in preapproval_events: + if event.actions and event.actions.state_delta: + state.update(event.actions.state_delta or {}) + ( + function_call_event, + approval_request_event, + function_response_event, + model_summary_event, + ) = preapproval_events + + assert len(function_call_event.content.parts) == 1 + assert ( + function_call_event.content.parts[0].function_call.name + == "increase_by_one" + ) + + assert len(approval_request_event.content.parts) == 1 + assert ( + approval_request_event.content.parts[0].function_call.name + == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ) + + assert ( + function_call_event.content.parts[0].function_call.id + != approval_request_event.content.parts[0].function_call.id + ) + assert ( + function_call_event.content.parts[0].function_call.id + == approval_request_event.content.parts[0].function_call.args[ + "function_call" + ]["id"] + ) + + assert len(function_response_event.content.parts) == 1 + assert ( + function_response_event.content.parts[0].function_response.response[ + "status" + ] + == "approval_requested" + ) + assert ( + function_call_event.content.parts[0].function_call.id + == function_response_event.content.parts[0].function_response.id + ) + + assert len(model_summary_event.content.parts) == 1 + assert model_summary_event.content.parts[0].text == "response1" + + assert ( + function_call_event.content.parts[0].function_call.id + in runner.session.state["approvals__suspended_function_calls"][0][ + "function_call" + ]["id"] + ) + + +@pytest.mark.asyncio +async def test_call_function_resumed(): + responses = [ + types.Part.from_function_call(name="increase_by_one", args={"x": 1}), + "response1", + "response2", + "response3", + ] + mock_model = MockModel.create(responses=responses) + + @tool_policy( + actions=[ApprovalAction("mutate_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + def increase_by_one(x: int) -> int: + return x + 1 + + agent = Agent(name="root_agent", model=mock_model, tools=[increase_by_one]) + runner = InMemoryRunner(agent) + preapproval_events = runner.run("test") + + ( + function_call_event, + approval_request_event, + function_response_event, + model_summary_event, + ) = preapproval_events + + assert ( + function_call_event.content.parts[0].function_call.name + == "increase_by_one" + ) + assert ( + approval_request_event.content.parts[0].function_call.name + == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ) + assert ( + function_response_event.content.parts[0].function_response.response[ + "status" + ] + == "approval_requested" + ) + assert model_summary_event.content.parts[0].text == "response1" + + assert ( + function_call_event.content.parts[0].function_call.id + in runner.session.state["approvals__suspended_function_calls"][0][ + "function_call" + ]["id"] + ) + + function_call_to_approve = 
approval_request_event.content.parts[ + 0 + ].function_call.args["function_call"] + grant = ApprovalGrant( + effect="allow", + actions=[ApprovalAction("mutate_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{function_call_to_approve['id']}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session_id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + postapproval_events = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant], + ).model_dump(), + ) + ], + ) + ) + assert len(postapproval_events) == 3 + state_delta_event, function_response_event, summary_event = ( + postapproval_events + ) + assert len(function_response_event.content.parts) == 1 + assert ( + function_response_event.content.parts[0].function_response.id + == function_call_to_approve["id"] + ) + assert ( + function_response_event.content.parts[0].function_response.name + == "increase_by_one" + ) + assert function_response_event.content.parts[ + 0 + ].function_response.response == {"result": 2} + + assert len(summary_event.content.parts) == 1 + assert summary_event.content.parts[0].text == "response2" + + assert runner.session.state["approvals__suspended_function_calls"] == [{ + "function_call": { + "args": {"x": 1}, + "id": function_call_event.content.parts[0].function_call.id, + "name": "increase_by_one", + }, + "sequence": 1, + "status": "resumed", + }] + + reapproval_events = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant], + ).model_dump(), + ) + ], + ) + ) + assert reapproval_events == [] + + +@pytest.mark.asyncio +async def test_partially_approving_challenges(): + responses = [ + types.Part.from_function_call(name="increase_by_one", args={"x": 1}), + "response1", + "response2", + "response3", + ] + mock_model = MockModel.create(responses=responses) + + @tool_policy( + actions=[ApprovalAction("mutate_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + @tool_policy( + actions=[ApprovalAction("read_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + def increase_by_one(x: int) -> int: + return x + 1 + + agent = Agent(name="root_agent", model=mock_model, tools=[increase_by_one]) + runner = InMemoryRunner(agent) + preapproval_events = await runner.run_async( + types.Content(role="user", parts=[types.Part.from_text(text="test")]) + ) + + ( + pre_function_call_event, + pre_approval_request_event, + pre_function_response_event, + pre_model_summary_event, + ) = preapproval_events + + assert ( + pre_function_call_event.content.parts[0].function_call.name + == "increase_by_one" + ) + assert ( + pre_approval_request_event.content.parts[0].function_call.name + == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ) + assert ( + pre_function_response_event.content.parts[0].function_response.response[ + "status" + ] + == "approval_requested" + ) + assert pre_model_summary_event.content.parts[0].text == "response1" + + assert len(runner.session.state["approvals__suspended_function_calls"]) == 1 + assert ( + pre_function_call_event.content.parts[0].function_call.id + in runner.session.state["approvals__suspended_function_calls"][0][ + "function_call" + ][ + "id" + ] 
# TODO filter to suspended + ) + + function_call_to_approve = pre_approval_request_event.content.parts[ + 0 + ].function_call.args["function_call"] + read_grant = ApprovalGrant( + effect="allow", + actions=[ApprovalAction("read_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{function_call_to_approve['id']}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session_id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + partial_approval_events = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[read_grant], + ).model_dump(), + ) + ], + ) + ) + ( + partial_grants_state_delta_event, + partial_function_response_event, + partial_summary_event, + ) = partial_approval_events + assert len(partial_function_response_event.content.parts) == 1 + assert ( + partial_function_response_event.content.parts[0].function_response.id + == function_call_to_approve["id"] + ) + assert ( + partial_function_response_event.content.parts[0].function_response.name + == "increase_by_one" + ) + assert ( + partial_function_response_event.content.parts[ + 0 + ].function_response.response["status"] + == "approval_requested" + ) + + assert len(partial_summary_event.content.parts) == 1 + assert partial_summary_event.content.parts[0].text == "response2" + + assert runner.session.state["approvals__suspended_function_calls"] == [{ + "function_call": function_call_to_approve, + "sequence": 1, + "status": "suspended", + }] + + mutate_grant = ApprovalGrant( + effect="allow", + actions=[ApprovalAction("mutate_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{function_call_to_approve['id']}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session_id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + postapproval_events = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[mutate_grant], + ).model_dump(), + ) + ], + ) + ) + ( + post_grants_state_delta_event, + post_function_response_event, + post_summary_event, + ) = postapproval_events + assert len(post_function_response_event.content.parts) == 1 + assert ( + post_function_response_event.content.parts[0].function_response.id + == function_call_to_approve["id"] + ) + assert ( + post_function_response_event.content.parts[0].function_response.name + == "increase_by_one" + ) + assert post_function_response_event.content.parts[ + 0 + ].function_response.response == {"result": 2} + + assert len(post_summary_event.content.parts) == 1 + assert post_summary_event.content.parts[0].text == "response3" + + assert runner.session.state["approvals__suspended_function_calls"] == [{ + "function_call": function_call_to_approve, + "sequence": 1, + "status": "resumed", + }] + + reapproval_events = runner.run( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[mutate_grant], + ).model_dump(), + ) + ], + ) + ) + assert reapproval_events == [] + + +async def test_parallel_call_function_resumed(): + responses = 
[ + [ + types.Part.from_function_call(name="increase_by_one", args={"x": 1}), + types.Part.from_function_call(name="increase_by_one", args={"x": 5}), + ], + "response1", + "response2", + "response3", + ] + mock_model = MockModel.create(responses=responses) + + @tool_policy( + actions=[ApprovalAction("mutate_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + def increase_by_one(x: int) -> int: + return x + 1 + + agent = Agent(name="root_agent", model=mock_model, tools=[increase_by_one]) + runner = InMemoryRunner(agent) + preapproval_events = await runner.run_async( + types.Content(role="user", parts=[types.Part.from_text(text="test")]) + ) + + ( + function_call_event, + approval_request_event, + function_response_event0, + model_summary_event, + ) = preapproval_events + + assert ( + function_call_event.content.parts[0].function_call.name + == "increase_by_one" + ) + assert ( + approval_request_event.content.parts[0].function_call.name + == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ) + assert ( + function_response_event0.content.parts[0].function_response.response[ + "status" + ] + == "approval_requested" + ) + assert model_summary_event.content.parts[0].text == "response1" + + assert ( + function_call_event.content.parts[0].function_call.id + in runner.session.state["approvals__suspended_function_calls"][0][ + "function_call" + ]["id"] + ) + + function_call_to_approve0, function_call_to_approve1 = [ + part.function_call.args["function_call"] for part in approval_request_event.content.parts + ] + + grant0, grant1 = [ + ApprovalGrant( + effect="allow", + actions=[ApprovalAction("mutate_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{function_call_part.function_call.args['function_call']['id']}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session_id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + for function_call_part in approval_request_event.content.parts + ] + approval_events0 = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant0], + ).model_dump(), + ) + ], + ) + ) + assert len(approval_events0) == 3 + state_delta_event0, function_response_event0, summary_event0 = ( + approval_events0 + ) + for event in approval_events0: + if event.actions and event.actions.state_delta: + runner.session.state.update(event.actions.state_delta or {}) + assert len(function_response_event0.content.parts) == 2 + assert ( + function_response_event0.content.parts[0].function_response.id + == function_call_to_approve0["id"] + ) + assert ( + function_response_event0.content.parts[0].function_response.name + == "increase_by_one" + ) + assert function_response_event0.content.parts[ + 0 + ].function_response.response == {"result": 2} + + assert len(summary_event0.content.parts) == 1 + assert summary_event0.content.parts[0].text == "response2" + + assert state_delta_event0.actions.state_delta["approvals__grants"] == [ + grant0.model_dump() + ] + + assert runner.session.state["approvals__suspended_function_calls"] == [ + { + "function_call": function_call_event.content.parts[ + 0 + ].function_call.model_dump(), + "sequence": 1, + "status": "resumed", + }, + { + "function_call": function_call_event.content.parts[ + 1 + ].function_call.model_dump(), + "sequence": 2, + "status": "suspended", + 
}, + ] + assert runner.session.state["approvals__grants"] == [grant0.model_dump()] + + approval_events1 = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant1], + ).model_dump(), + ) + ], + ) + ) + state_delta_event1, function_response_event1, summary_event1 = ( + approval_events1 + ) + for event in approval_events1: + if event.actions and event.actions.state_delta: + runner.session.state.update(event.actions.state_delta or {}) + + assert runner.session.state["approvals__grants"] == [ + grant0.model_dump(), + grant1.model_dump(), + ] + assert runner.session.state["approvals__suspended_function_calls"] == [ + { + "function_call": function_call_event.content.parts[ + 1 + ].function_call.model_dump(), + "sequence": 1, + "status": "resumed", + }, + ] + + assert len(function_response_event1.content.parts) == 1 + assert ( + function_response_event1.content.parts[0].function_response.id + == function_call_to_approve1["id"] + ) + assert ( + function_response_event1.content.parts[0].function_response.name + == "increase_by_one" + ) + assert function_response_event1.content.parts[ + 0 + ].function_response.response == {"result": 6} + + assert len(summary_event1.content.parts) == 1 + assert summary_event1.content.parts[0].text == "response3" + + +async def test_sequential_call_function(): + responses = [ + types.Part.from_function_call(name="increase_by_one", args={"x": 1}), + "approval is needed", + "tool was called", + types.Part.from_function_call(name="increase_by_one", args={"x": 5}), + "approval is needed again", + "tool was called again", + ] + mock_model = MockModel.create(responses=responses) + + @tool_policy( + actions=[ApprovalAction("mutate_numbers")], + resources=lambda args: [f"{TOOL_NAMESPACE}:integers"], + ) + def increase_by_one(x: int) -> int: + return x + 1 + + agent = Agent(name="root_agent", model=mock_model, tools=[increase_by_one]) + runner = InMemoryRunner(agent) + preapproval_events = await runner.run_async( + types.Content(role="user", parts=[types.Part.from_text(text="call the function with x as 1")]) + ) + + ( + function_call_event0, + approval_request_event0, + function_response_event0, + model_summary_event0, + ) = preapproval_events + + assert ( + function_call_event0.content.parts[0].function_call.name + == "increase_by_one" + ) + assert ( + approval_request_event0.content.parts[0].function_call.name + == REQUEST_APPROVAL_FUNCTION_CALL_NAME + ) + assert ( + function_response_event0.content.parts[0].function_response.response[ + "status" + ] + == "approval_requested" + ) + assert model_summary_event0.content.parts[0].text == "approval is needed" + + assert ( + function_call_event0.content.parts[0].function_call.id + in runner.session.state["approvals__suspended_function_calls"][0][ + "function_call" + ]["id"] + ) + + function_call_to_approve0 = approval_request_event0.content.parts[ + 0 + ].function_call.args["function_call"] + grant0 = ApprovalGrant( + effect="allow", + actions=[ApprovalAction("mutate_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{approval_request_event0.content.parts[0].function_call.args['function_call']['id']}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session_id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + approval_events0 = await 
runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant0], + ).model_dump(), + ) + ], + ) + ) + assert len(approval_events0) == 3 + state_delta_event0, function_response_event0, summary_event0 = ( + approval_events0 + ) + for event in approval_events0: + if event.actions and event.actions.state_delta: + runner.session.state.update(event.actions.state_delta or {}) + assert len(function_response_event0.content.parts) == 1 + assert ( + function_response_event0.content.parts[0].function_response.id + == function_call_to_approve0["id"] + ) + assert ( + function_response_event0.content.parts[0].function_response.name + == "increase_by_one" + ) + assert function_response_event0.content.parts[ + 0 + ].function_response.response == {"result": 2} + + assert len(summary_event0.content.parts) == 1 + assert summary_event0.content.parts[0].text == "tool was called" + + assert state_delta_event0.actions.state_delta["approvals__grants"] == [ + grant0.model_dump() + ] + + assert runner.session.state["approvals__suspended_function_calls"] == [ + { + "function_call": function_call_event0.content.parts[ + 0 + ].function_call.model_dump(), + "sequence": 1, + "status": "resumed", + }, + ] + assert runner.session.state["approvals__grants"] == [grant0.model_dump()] + + second_preapproval_events = await runner.run_async( + types.Content(role="user", parts=[types.Part.from_text(text="call it again with 5")]) + ) + ( + function_call_event1, + approval_request_event1, + function_response_event1, + model_summary_event1, + ) = second_preapproval_events + + function_call_to_approve1 = approval_request_event1.content.parts[ + 0 + ].function_call.args["function_call"] + + grant1 = ApprovalGrant( + effect="allow", + actions=[ApprovalAction("mutate_numbers")], + resources=[f"{TOOL_NAMESPACE}:integers"], + grantee=ApprovalActor( + type="tool", + id=f"tool:increase_by_one:{approval_request_event1.content.parts[0].function_call.args['function_call']['id']}", + on_behalf_of=ApprovalActor( + type="agent", + id=runner.session_id, + on_behalf_of=ApprovalActor(type="user", id="test_user"), + ), + ), + grantor=ApprovalActor(type="user", id="test_user"), + ) + approval_events1 = await runner.run_async( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=REQUEST_APPROVAL_FUNCTION_CALL_NAME, + response=ApprovalResponse( + grants=[grant1], + ).model_dump(), + ) + ], + ) + ) + state_delta_event1, function_response_event1, summary_event1 = ( + approval_events1 + ) + for event in approval_events1: + if event.actions and event.actions.state_delta: + runner.session.state.update(event.actions.state_delta or {}) + + assert runner.session.state["approvals__grants"] == [ + grant0.model_dump(), + grant1.model_dump(), + ] + assert runner.session.state["approvals__suspended_function_calls"] == [ + { + "function_call": function_call_event1.content.parts[ + 0 + ].function_call.model_dump(), + "sequence": 1, + "status": "resumed", + }, + ] + + assert len(function_response_event1.content.parts) == 1 + assert ( + function_response_event1.content.parts[0].function_response.id + == function_call_to_approve1["id"] + ) + assert ( + function_response_event1.content.parts[0].function_response.name + == "increase_by_one" + ) + assert function_response_event1.content.parts[ + 0 + ].function_response.response == {"result": 6} + + assert len(summary_event1.content.parts) == 1 + assert 
summary_event1.content.parts[0].text == "tool was called again" + +# TODO regression test for past tool calls being called again on second approval +# TODO validate in the tests that the message history is valid (no duplicate tool calls) diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py index 35f3a811f6..950b4763cc 100644 --- a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -85,7 +85,7 @@ def simple_fn(**kwargs) -> Dict[str, Any]: agent=agent, user_content="" ) # Build function call event - function_call = types.FunctionCall(name=tool.name, args={}) + function_call = types.FunctionCall(name=tool.name, args={}, id="test_id") content = types.Content(parts=[types.Part(function_call=function_call)]) event = Event( invocation_id=invocation_context.invocation_id, diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index 139e0d5647..5f211c8351 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -18,6 +18,8 @@ from typing import Generator from typing import Union +import google.genai.errors + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import Agent @@ -177,17 +179,14 @@ def __init__( @property def session(self) -> Session: if not self.session_id: - session = asyncio.run( - self.runner.session_service.create_session( - app_name='test_app', user_id='test_user' - ) + session = self.runner.session_service.create_session_sync( + app_name='test_app', user_id='test_user' ) + self.session_id = session.id return session - return asyncio.run( - self.runner.session_service.get_session( - app_name='test_app', user_id='test_user', session_id=self.session_id - ) + return self.runner.session_service.get_session_sync( + app_name='test_app', user_id='test_user', session_id=self.session_id ) def run(self, new_message: types.ContentUnion) -> list[Event]: @@ -199,6 +198,16 @@ def run(self, new_message: types.ContentUnion) -> list[Event]: ) ) + async def run_async(self, new_message: types.ContentUnion) -> list[Event]: + events = [] + async for event in self.runner.run_async( + user_id=self.session.user_id, + session_id=self.session.id, + new_message=new_message, + ): + events.append(event) + return events + def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: collected_responses = [] @@ -267,6 +276,7 @@ def generate_content( self, llm_request: LlmRequest, stream: bool = False ) -> Generator[LlmResponse, None, None]: # Increasement of the index has to happen before the yield. + self.validate(llm_request) self.response_index += 1 self.requests.append(llm_request) # yield LlmResponse(content=self.responses[self.response_index]) @@ -277,10 +287,38 @@ async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False ) -> AsyncGenerator[LlmResponse, None]: # Increasement of the index has to happen before the yield. 
+    self.validate(llm_request)
     self.response_index += 1
     self.requests.append(llm_request)
     yield self.responses[self.response_index]
 
+  def validate(self, llm_request: LlmRequest):
+    """Rejects histories whose function calls and responses are unbalanced."""
+    function_calls = [
+        part
+        for content in llm_request.contents
+        for part in content.parts
+        if hasattr(part, 'function_call') and part.function_call
+    ]
+    function_responses = [
+        part
+        for content in llm_request.contents
+        for part in content.parts
+        if hasattr(part, 'function_response') and part.function_response
+    ]
+
+    if len(function_calls) != len(function_responses):
+      raise google.genai.errors.ClientError(
+          code=400,
+          response_json={
+              'error': {
+                  'code': 400,
+                  'message': (
+                      'Please ensure that the number of function response'
+                      ' parts is equal to the number of function call parts'
+                      ' of the function call turn.'
+                  ),
+                  'status': 'INVALID_ARGUMENT',
+                  # Not part of the normal error payload; included to aid
+                  # debugging in tests.
+                  'debug_function_calls': function_calls,
+                  'debug_function_responses': function_responses,
+              }
+          }
+      )
+
   @contextlib.asynccontextmanager
   async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
     """Creates a live connection to the LLM."""

From 4a8af9e9d6eca821b1da7858e43bf582fab6ed1f Mon Sep 17 00:00:00 2001
From: Calvin Giles
Date: Wed, 4 Jun 2025 11:16:32 +1200
Subject: [PATCH 2/7] feat: Add support for vertex gemini model optimizer

---
 src/google/adk/models/google_llm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 9c03b2867d..7b854f2318 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -66,6 +66,8 @@ def supported_models() -> list[str]:
     return [
         r'gemini-.*',
+        # model optimizer pattern
+        r'model-optimizer-.*',
         # fine-tuned vertex endpoint pattern
         r'projects\/.+\/locations\/.+\/endpoints\/.+',
         # vertex gemini long name

From 5acf1537fddb7981b697d4a82502c047d2837b99 Mon Sep 17 00:00:00 2001
From: Calvin Giles
Date: Wed, 4 Jun 2025 19:49:04 +1200
Subject: [PATCH 3/7] tidy

---
 src/google/adk/approval/approval_policy.py  | 1 +
 src/google/adk/flows/llm_flows/functions.py | 1 -
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/google/adk/approval/approval_policy.py b/src/google/adk/approval/approval_policy.py
index d3799c9dc1..1a0dada26b 100644
--- a/src/google/adk/approval/approval_policy.py
+++ b/src/google/adk/approval/approval_policy.py
@@ -11,6 +11,7 @@
 
 Policies determine what actions on what resources require approval before a
 tool can be executed.
""" +from __future__ import annotations from typing import Any, Callable, Optional from pydantic import BaseModel diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index cca5947f55..0d802d2cfb 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -37,7 +37,6 @@ from ...telemetry import trace_merged_tool_calls from ...sessions import State from ...telemetry import trace_tool_call -from ...telemetry import trace_tool_response from ...telemetry import tracer from ...tools.base_tool import BaseTool from ...tools.tool_context import ToolContext From 531508ca96825189a0499c3c0b924a686a960f51 Mon Sep 17 00:00:00 2001 From: Calvin Giles Date: Wed, 4 Jun 2025 20:17:03 +1200 Subject: [PATCH 4/7] fix --- src/google/adk/flows/llm_flows/functions.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 0d802d2cfb..835a5165b1 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -250,11 +250,11 @@ async def handle_function_calls_async( if not function_response: continue - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - function_response_events.append(function_response_event) + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + function_response_events.append(function_response_event) if not function_response_events: return None @@ -266,10 +266,9 @@ async def handle_function_calls_async( # this is needed for debug traces of parallel calls # individual response with tool.name is traced in __build_response_event # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span('tool_response'): - trace_tool_response( - invocation_context=invocation_context, - event_id=merged_event.id, + with tracer.start_as_current_span('execute_tool (merged)'): + trace_merged_tool_calls( + response_event_id=merged_event.id, function_response_event=merged_event, ) return merged_event From d3100f5133b094c3f6106b4d8c19e9b7e3d064a7 Mon Sep 17 00:00:00 2001 From: Calvin Giles Date: Thu, 5 Jun 2025 12:29:10 +1200 Subject: [PATCH 5/7] fix --- src/google/adk/flows/llm_flows/functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 835a5165b1..cd99a1636a 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -254,6 +254,11 @@ async def handle_function_calls_async( function_response_event = __build_response_event( tool, function_response, tool_context, invocation_context ) + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) function_response_events.append(function_response_event) if not function_response_events: From 19ef066e5c608d4d7d67899f99b48b7f5f42e335 Mon Sep 17 00:00:00 2001 From: Calvin Giles Date: Fri, 6 Jun 2025 16:37:57 +1200 Subject: [PATCH 6/7] fix: Notify agent of failed agent transfer in function response --- .../adk/flows/llm_flows/base_llm_flow.py | 40 ++++++++++-- .../flows/llm_flows/test_agent_transfer.py | 63 ++++++++++++++++++- 2 files changed, 96 insertions(+), 7 
deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 5fbfab24d7..de3f6683c3 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -439,15 +439,29 @@ async def _postprocess_live( function_response_event = await functions.handle_function_calls_live( invocation_context, model_response_event, llm_request.tools_dict ) - yield function_response_event transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( invocation_context, transfer_to_agent ) - async for item in agent_to_run.run_live(invocation_context): - yield item + transfer_successful = False + try: + async for item in agent_to_run.run_live(invocation_context): + if not transfer_successful: + yield function_response_event + transfer_successful = True + yield item + finally: + if not transfer_successful: + function_response_event.content.parts[ + 0 + ].function_response.response = { + 'result': f'Error transferring to {transfer_to_agent}' + } + yield function_response_event + else: + yield function_response_event async def _postprocess_run_processors_async( self, invocation_context: InvocationContext, llm_response: LlmResponse @@ -471,14 +485,28 @@ async def _postprocess_handle_function_calls_async( if auth_event: yield auth_event - yield function_response_event transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( invocation_context, transfer_to_agent ) - async for event in agent_to_run.run_async(invocation_context): - yield event + transfer_successful = False + try: + async for event in agent_to_run.run_async(invocation_context): + if not transfer_successful: + yield function_response_event + transfer_successful = True + yield event + finally: + if not transfer_successful: + function_response_event.content.parts[ + 0 + ].function_response.response = { + 'result': f'Error transferring to {transfer_to_agent}' + } + yield function_response_event + else: + yield function_response_event def _get_agent_to_run( self, invocation_context: InvocationContext, agent_name: str diff --git a/tests/unittests/flows/llm_flows/test_agent_transfer.py b/tests/unittests/flows/llm_flows/test_agent_transfer.py index fe26c42a36..6148394520 100644 --- a/tests/unittests/flows/llm_flows/test_agent_transfer.py +++ b/tests/unittests/flows/llm_flows/test_agent_transfer.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from google.adk.agents.llm_agent import Agent from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.sequential_agent import SequentialAgent -from google.adk.tools import exit_loop +from google.adk.tools import BaseTool, FunctionTool, exit_loop from google.genai.types import Part +from google.adk.tools.base_toolset import BaseToolset from ... 
import testing_utils @@ -311,3 +314,61 @@ def test_auto_to_loop(): assert testing_utils.simplify_events(runner.run('test2')) == [ ('root_agent', 'response5'), ] + + +class FakeError(Exception): + pass + + +class BrokenToolset(BaseToolset): + + async def get_tools( + self, + readonly_context: Optional[ReadonlyContext] = None, + ) -> list[BaseTool]: + raise FakeError() + + async def close(self): + pass + + +def test_transfer_error_handling(): + response = [ + transfer_call_part('sub_agent_1'), + 'response1', + 'response2', + ] + mockModel = testing_utils.MockModel.create(responses=response) + # root (auto) - sub_agent_1 (auto) + + sub_agent_1 = Agent( + name='sub_agent_1', + model=mockModel, + tools=[ + BrokenToolset(), + ], + ) + root_agent = Agent( + name='root_agent', + model=mockModel, + sub_agents=[sub_agent_1], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + + # Asserts the transfer. + assert testing_utils.simplify_events(runner.run('test1')) == [ + ('root_agent', transfer_call_part('sub_agent_1')), + ( + 'root_agent', + Part.from_function_response( + name='transfer_to_agent', + response={'result': 'Error transferring to sub_agent_1'}, + ), + ), + ] + + # root_agent should still be the current agent as the transfer failed. + assert testing_utils.simplify_events(runner.run('test2')) == [ + ('root_agent', 'response1'), + ] From ef8ec81cc92a3ef1c6125aa471bae36e968f6386 Mon Sep 17 00:00:00 2001 From: Calvin Giles Date: Sat, 7 Jun 2025 12:13:27 +1200 Subject: [PATCH 7/7] Policy matching update --- src/google/adk/approval/approval_policy.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/google/adk/approval/approval_policy.py b/src/google/adk/approval/approval_policy.py index 1a0dada26b..6c8c705e4e 100644 --- a/src/google/adk/approval/approval_policy.py +++ b/src/google/adk/approval/approval_policy.py @@ -128,8 +128,13 @@ def register_tool_policy(cls, policy: FunctionToolPolicy): """ if policy.policy_name is None: raise ValueError("Policy name cannot be None") - if policy not in cls.tool_policies: - cls.tool_policies.append(policy) + for existing_policy in cls.tool_policies: + if existing_policy.policy_name != policy.policy_name: + continue + if existing_policy.actions != policy.actions: + continue + return + cls.tool_policies.append(policy) @classmethod def get_tool_policies(cls, tool_name: str) -> list[FunctionToolPolicy]:
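
Taken together, the series works as follows: a tool declares the actions and resources it needs via @tool_policy; the first model call to that tool is suspended and surfaced to the client as a REQUEST_APPROVAL_FUNCTION_CALL_NAME function call; the client answers with an ApprovalResponse carrying ApprovalGrant objects, after which the suspended call is resumed and the grants are persisted in session state. The sketch below is assembled from the API exercised in test_approval.py; the delete_file tool, the 'delete_files' action, and the path-based resource string are illustrative placeholders, not part of the series.

from google.adk.approval.approval_grant import (
    ApprovalAction,
    ApprovalActor,
    ApprovalGrant,
)
from google.adk.approval.approval_policy import TOOL_NAMESPACE, tool_policy
from google.adk.approval.approval_request import ApprovalResponse
from google.adk.approval.approval_request_processor import (
    REQUEST_APPROVAL_FUNCTION_CALL_NAME,
)
from google.genai import types


# Declare that calling this tool exercises the (illustrative) 'delete_files'
# action on a resource derived from the tool's arguments.
@tool_policy(
    actions=[ApprovalAction("delete_files")],
    resources=lambda args: [f"{TOOL_NAMESPACE}:files:{args['path']}"],
)
def delete_file(path: str) -> dict:
  # Only runs once a grant covering the action and resource exists.
  return {"deleted": path}


def approval_reply(
    function_call_to_approve: dict, session_id: str
) -> types.Content:
  # Build the user turn that answers the approval request for a suspended
  # function call; the dict comes from the approval request's
  # `function_call` argument, as in the tests above.
  grant = ApprovalGrant(
      effect="allow",
      actions=[ApprovalAction("delete_files")],
      resources=[f"{TOOL_NAMESPACE}:files:/tmp/scratch.txt"],
      grantee=ApprovalActor(
          type="tool",
          id=f"tool:delete_file:{function_call_to_approve['id']}",
          on_behalf_of=ApprovalActor(
              type="agent",
              id=session_id,
              on_behalf_of=ApprovalActor(type="user", id="test_user"),
          ),
      ),
      grantor=ApprovalActor(type="user", id="test_user"),
  )
  return types.Content(
      role="user",
      parts=[
          types.Part.from_function_response(
              name=REQUEST_APPROVAL_FUNCTION_CALL_NAME,
              response=ApprovalResponse(grants=[grant]).model_dump(),
          )
      ],
  )

Sending this content on the next run resumes the suspended call: the resulting function response event reuses the original function_call id, the grant is appended under the approvals__grants state key, and the matching entry in approvals__suspended_function_calls flips to status 'resumed'.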
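
PATCH 7/7 changes ApprovalPolicyRegistry.register_tool_policy from whole-object equality to idempotence on the (policy_name, actions) pair: registration is skipped only when an existing policy matches both fields, so a same-named policy with different actions is still added. A predicate mirroring that skip condition, assuming only the policy_name and actions attributes visible in the diff:

def is_duplicate_registration(existing, policy) -> bool:
  # Mirrors the skip condition in register_tool_policy: both the name and
  # the action list must match for the new policy to be dropped.
  return (
      existing.policy_name == policy.policy_name
      and existing.actions == policy.actions
  )

Because registration happens at decorator time, the conftest.py fixtures above matter for test hygiene: the autouse isolate_policy_registry fixture clears ApprovalPolicyRegistry.tool_policies around each test, and use_system_policies copies the backed-up system policies back in for tests that want them.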
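
MockModel.validate approximates the live Gemini API's INVALID_ARGUMENT check by counting function_call and function_response parts across the whole request and raising a synthetic 400 ClientError when the totals differ. A minimal balanced history that would pass the check (tool name and values taken from the tests above):

from google.genai import types

balanced_contents = [
    # One model turn issuing a function call...
    types.Content(
        role="model",
        parts=[
            types.Part.from_function_call(name="increase_by_one", args={"x": 1})
        ],
    ),
    # ...matched by exactly one function response part.
    types.Content(
        role="user",
        parts=[
            types.Part.from_function_response(
                name="increase_by_one", response={"result": 2}
            )
        ],
    ),
]

Dropping either part leaves the counts unequal, which is exactly the malformed history the approval flow could produce if a suspended call were replayed incorrectly; hence the two TODOs at the end of test_approval.py.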
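
PATCH 6/7 defers yielding the transfer_to_agent function response until the target agent has produced its first event; if the sub-agent fails before then (here, a toolset whose get_tools raises), the same response part is rewritten in place and yielded as an error, leaving the original agent in control. The shape a client observes, taken from test_transfer_error_handling:

from google.genai.types import Part

failed_transfer_response = Part.from_function_response(
    name='transfer_to_agent',
    response={'result': 'Error transferring to sub_agent_1'},
)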