diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py
index 485a1c50be..4117f60171 100644
--- a/vertexai/_genai/_agent_engines_utils.py
+++ b/vertexai/_genai/_agent_engines_utils.py
@@ -14,10 +14,22 @@
 #
 """Utility functions for agent engines."""
 
-from typing import Any, Callable, Coroutine, Iterator, AsyncIterator
+from typing import Any, Callable, Coroutine, Iterator, AsyncIterator, Protocol, Union
 
 from . import types
 
 
+AgentEngineOperationUnion = Union[
+    types.AgentEngineOperation,
+    types.AgentEngineMemoryOperation,
+    types.AgentEngineGenerateMemoriesOperation,
+]
+
+
+class GetOperationFunction(Protocol):
+    def __call__(self, *, operation_name: str, **kwargs) -> AgentEngineOperationUnion:
+        ...
+
+
 def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
     """Wraps an Agent Engine method, creating a callable for `query` API.
diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py
index 58bd9f25a4..2985fd570d 100644
--- a/vertexai/_genai/agent_engines.py
+++ b/vertexai/_genai/agent_engines.py
@@ -18,12 +18,11 @@
 import json
 import logging
 import time
-from typing import Any, Callable, Iterator, Optional, Sequence, Union
+from typing import Any, Iterator, Optional, Sequence, Tuple, Union
 from urllib.parse import urlencode
 
 from google.genai import _api_module
 from google.genai import _common
-from google.genai import types as genai_types
 from google.genai._common import get_value_by_path as getv
 from google.genai._common import set_value_by_path as setv
 from google.genai.pagers import Pager
@@ -2045,10 +2044,13 @@ def create(
             api_async_client=AsyncAgentEngines(api_client_=self._api_client),
             api_resource=operation.response,
         )
-        logger.info("Agent Engine created. To use it in another session:")
-        logger.info(
-            f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
-        )
+        if agent.api_resource:
+            logger.info("Agent Engine created. To use it in another session:")
+            logger.info(
+                f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
+            )
+        else:
+            logger.warning("The operation returned an empty response.")
         if agent_engine is not None:
             # If the user did not provide an agent_engine (e.g. lightweight
             # provisioning), it will not have any API methods registered.
@@ -2067,13 +2069,13 @@ def _create_config(
         gcs_dir_name: Optional[str] = None,
         extra_packages: Optional[Sequence[str]] = None,
         env_vars: Optional[dict[str, Union[str, Any]]] = None,
-        context_spec: Optional[dict[str, Any]] = None,
-    ):
+        context_spec: Optional[types.ReasoningEngineContextSpecDict] = None,
+    ) -> types.UpdateAgentEngineConfigDict:
         import sys
 
         from vertexai.agent_engines import _agent_engines
         from vertexai.agent_engines import _utils
 
-        config = {}
+        config: types.UpdateAgentEngineConfigDict = {}
         update_masks = []
         if mode not in ["create", "update"]:
             raise ValueError(f"Unsupported mode: {mode}")
@@ -2091,19 +2093,27 @@ def _create_config(
         if context_spec is not None:
             config["context_spec"] = context_spec
         if agent_engine is not None:
+            project = self._api_client.project
+            if project is None:
+                raise ValueError("project must be set using `vertexai.Client`.")
+            location = self._api_client.location
+            if location is None:
+                raise ValueError("location must be set using `vertexai.Client`.")
             sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"
             gcs_dir_name = gcs_dir_name or _agent_engines._DEFAULT_GCS_DIR_NAME
             agent_engine = _agent_engines._validate_agent_engine_or_raise(
                 agent_engine=agent_engine, logger=logger
             )
-            _agent_engines._validate_staging_bucket_or_raise(staging_bucket)
+            staging_bucket = _agent_engines._validate_staging_bucket_or_raise(
+                staging_bucket=staging_bucket
+            )
             requirements = _agent_engines._validate_requirements_or_raise(
                 agent_engine=agent_engine,
                 requirements=requirements,
                 logger=logger,
             )
             extra_packages = _agent_engines._validate_extra_packages_or_raise(
-                extra_packages
+                extra_packages=extra_packages,
             )
             # Prepares the Agent Engine for creation/update in Vertex AI. This
             # involves packaging and uploading the artifacts for agent_engine,
@@ -2111,8 +2121,8 @@ def _create_config(
             _agent_engines._prepare(
                 agent_engine=agent_engine,
                 requirements=requirements,
-                project=self._api_client.project,
-                location=self._api_client.location,
+                project=project,
+                location=location,
                 staging_bucket=staging_bucket,
                 gcs_dir_name=gcs_dir_name,
                 extra_packages=extra_packages,
@@ -2142,7 +2152,9 @@ def _create_config(
                 gcs_dir_name,
                 _agent_engines._REQUIREMENTS_FILE,
             )
-            agent_engine_spec = {"package_spec": package_spec}
+            agent_engine_spec: types.ReasoningEngineSpecDict = {
+                "package_spec": package_spec,
+            }
             if env_vars is not None:
                 (
                     deployment_spec,
@@ -2172,7 +2184,7 @@ def _generate_deployment_spec_or_raise(
         self,
         *,
         env_vars: Optional[dict[str, Union[str, Any]]] = None,
-    ):
+    ) -> Tuple[dict[str, Any], Sequence[str]]:
         deployment_spec: dict[str, Any] = {}
         update_masks = []
         if env_vars:
@@ -2217,7 +2229,7 @@ def _await_operation(
         *,
         operation_name: str,
         poll_interval_seconds: int = 10,
-        get_operation_fn: Optional[Callable[[str], Any]] = None,
+        get_operation_fn: Optional[_agent_engines_utils.GetOperationFunction] = None,
     ) -> Any:
        """Waits for the operation for creating an agent engine to complete.
 
@@ -2375,10 +2387,11 @@ def update(
             api_async_client=AsyncAgentEngines(api_client_=self._api_client),
             api_resource=operation.response,
         )
-        logger.info("Agent Engine updated. To use it in another session:")
-        logger.info(
-            f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
-        )
+        if agent.api_resource:
+            logger.info("Agent Engine updated. To use it in another session:")
+            logger.info(
+                f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
+            )
         return self._register_api_methods(agent=agent)
 
     def _stream_query(
@@ -2403,7 +2416,7 @@ def _stream_query(
             path = f"{path}?{urlencode(query_params)}"
         # TODO: remove the hack that pops config.
         request_dict.pop("config", None)
-        http_options: Optional[genai_types.HttpOptions] = None
+        http_options = None
         if (
             parameter_model.config is not None
             and parameter_model.config.http_options is not None
diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py
index 1bceefd3de..3ac9ff8bf3 100644
--- a/vertexai/agent_engines/_agent_engines.py
+++ b/vertexai/agent_engines/_agent_engines.py
@@ -16,6 +16,7 @@
 import abc
 import inspect
 import io
+import logging
 import os
 import sys
 import tarfile
@@ -66,7 +67,7 @@
 _DEFAULT_STREAM_METHOD_NAME = "stream_query"
 _DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query"
 _DEFAULT_METHOD_RETURN_TYPE = "dict[str, Any]"
-_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]"
+_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any, Any, Any]"
 _DEFAULT_STREAM_METHOD_RETURN_TYPE = "Iterable[Any]"
 _DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE = "AsyncIterable[Any]"
 _DEFAULT_METHOD_DOCSTRING_TEMPLATE = """
@@ -116,7 +117,7 @@ class Queryable(Protocol):
     """Protocol for Agent Engines that can be queried."""
 
     @abc.abstractmethod
-    def query(self, **kwargs):
+    def query(self, **kwargs) -> Any:
         """Runs the Agent Engine to serve the user query."""
 
 
@@ -125,7 +126,7 @@ class AsyncQueryable(Protocol):
     """Protocol for Agent Engines that can be queried asynchronously."""
 
     @abc.abstractmethod
-    def async_query(self, **kwargs):
+    def async_query(self, **kwargs) -> Coroutine[Any, Any, Any]:
         """Runs the Agent Engine to serve the user query asynchronously."""
 
 
@@ -143,7 +144,7 @@ class StreamQueryable(Protocol):
     """Protocol for Agent Engines that can stream responses."""
 
     @abc.abstractmethod
-    def stream_query(self, **kwargs):
+    def stream_query(self, **kwargs) -> Iterable[Any]:
         """Stream responses to serve the user query."""
 
 
@@ -152,7 +153,7 @@ class Cloneable(Protocol):
     """Protocol for Agent Engines that can be cloned."""
 
     @abc.abstractmethod
-    def clone(self):
+    def clone(self) -> Any:
         """Return a clone of the object."""
 
 
@@ -161,7 +162,7 @@ class OperationRegistrable(Protocol):
     """Protocol for agents that have registered operations."""
 
     @abc.abstractmethod
-    def register_operations(self, **kwargs):
+    def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
         """Register the user provided operations (modes and methods)."""
 
 
@@ -238,7 +239,7 @@ def clone(self):
             sys_paths=self._tmpl_attrs.get("sys_paths"),
         )
 
-    def register_operations(self) -> Dict[str, Sequence[str]]:
+    def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
         return self._tmpl_attrs.get("register_operations")
 
     def set_up(self) -> None:
@@ -441,7 +442,7 @@ def create(
         staging_bucket = initializer.global_config.staging_bucket
     if agent_engine is not None:
         agent_engine = _validate_agent_engine_or_raise(agent_engine)
-        _validate_staging_bucket_or_raise(staging_bucket)
+        staging_bucket = _validate_staging_bucket_or_raise(staging_bucket)
     if agent_engine is None:
         if requirements is not None:
             raise ValueError("requirements must be None if agent_engine is None.")
@@ -634,7 +635,7 @@ def update(
                 nonexistent file.
""" staging_bucket = initializer.global_config.staging_bucket - _validate_staging_bucket_or_raise(staging_bucket) + staging_bucket = _validate_staging_bucket_or_raise(staging_bucket) historical_operation_schemas = self.operation_schemas() gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME @@ -780,12 +781,13 @@ def _validate_sys_version_or_raise(sys_version: str) -> None: ) -def _validate_staging_bucket_or_raise(staging_bucket: str) -> str: +def _validate_staging_bucket_or_raise(staging_bucket: Optional[str]) -> str: """Tries to validate the staging bucket.""" if not staging_bucket: raise ValueError("Please provide a `staging_bucket` in `vertexai.init(...)`") if not staging_bucket.startswith("gs://"): raise ValueError(f"{staging_bucket=} must start with `gs://`") + return staging_bucket def _validate_agent_engine_or_raise( @@ -906,7 +908,7 @@ def _validate_requirements_or_raise( *, agent_engine: _AgentEngineInterface, requirements: Optional[Sequence[str]] = None, - logger: base.Logger = _LOGGER, + logger: logging.getLoggerClass() = _LOGGER, ) -> Sequence[str]: """Tries to validate the requirements.""" if requirements is None: @@ -929,7 +931,7 @@ def _validate_requirements_or_raise( def _validate_extra_packages_or_raise( - extra_packages: Sequence[str], + extra_packages: Optional[Sequence[str]], build_options: Optional[Dict[str, Sequence[str]]] = None, ) -> Sequence[str]: """Tries to validates the extra packages.""" @@ -1165,6 +1167,7 @@ def _get_agent_framework( if ( hasattr(agent_engine, _AGENT_FRAMEWORK_ATTR) and getattr(agent_engine, _AGENT_FRAMEWORK_ATTR) is not None + and isinstance(getattr(agent_engine, _AGENT_FRAMEWORK_ATTR), str) ): return getattr(agent_engine, _AGENT_FRAMEWORK_ATTR) return _DEFAULT_AGENT_FRAMEWORK diff --git a/vertexai/reasoning_engines/_utils.py b/vertexai/reasoning_engines/_utils.py index ec16dbb386..338cf3c9f6 100644 --- a/vertexai/reasoning_engines/_utils.py +++ b/vertexai/reasoning_engines/_utils.py @@ -103,7 +103,7 @@ def to_dict(message: proto.Message) -> JsonDict: return result -def dataclass_to_dict(obj: dataclasses.dataclass) -> JsonDict: +def dataclass_to_dict(obj: dataclasses.dataclass) -> Any: """Converts a dataclass to a JSON dictionary. Args: @@ -116,7 +116,7 @@ def dataclass_to_dict(obj: dataclasses.dataclass) -> JsonDict: return json.loads(json.dumps(dataclasses.asdict(obj))) -def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]: +def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Any: response = {} if hasattr(obj, "response"): response["response"] = obj.response @@ -128,15 +128,11 @@ def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]: return json.loads(json.dumps(response)) -def _llama_index_chat_response_to_dict( - obj: LlamaIndexChatResponse, -) -> Dict[str, Any]: +def _llama_index_chat_response_to_dict(obj: LlamaIndexChatResponse) -> Any: return json.loads(obj.message.model_dump_json()) -def _llama_index_base_model_to_dict( - obj: LlamaIndexBaseModel, -) -> Dict[str, Any]: +def _llama_index_base_model_to_dict(obj: LlamaIndexBaseModel) -> Any: return json.loads(obj.model_dump_json()) @@ -330,7 +326,7 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: except ImportError as e: raise ImportError( "Cloud Storage is not installed. Please call " - "'pip install google-cloud-aiplatform[reasoningengine]'." + "'pip install google-cloud-aiplatform[agent_engines]'." 
         ) from e
     return storage
 
@@ -342,7 +338,7 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
     except ImportError as e:
         raise ImportError(
             "cloudpickle is not installed. Please call "
-            "'pip install google-cloud-aiplatform[reasoningengine]'."
+            "'pip install google-cloud-aiplatform[agent_engines]'."
         ) from e
     return cloudpickle
 
@@ -358,7 +354,7 @@ def _import_pydantic_or_raise() -> types.ModuleType:
     except ImportError as e:
         raise ImportError(
             "pydantic is not installed. Please call "
-            "'pip install google-cloud-aiplatform[reasoningengine]'."
+            "'pip install google-cloud-aiplatform[agent_engines]'."
         ) from e
     return pydantic
 
@@ -372,7 +368,7 @@ def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
     except ImportError:
         _LOGGER.warning(
             "opentelemetry-sdk is not installed. Please call "
-            "'pip install google-cloud-aiplatform[reasoningengine]'."
+            "'pip install google-cloud-aiplatform[agent_engines]'."
         )
         return None
 
@@ -386,7 +382,7 @@ def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
     except ImportError:
         _LOGGER.warning(
             "opentelemetry-sdk is not installed. Please call "
-            "'pip install google-cloud-aiplatform[reasoningengine]'."
+            "'pip install google-cloud-aiplatform[agent_engines]'."
         )
         return None
 
@@ -400,7 +396,7 @@ def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
     except ImportError:
         _LOGGER.warning(
             "google-cloud-trace is not installed. Please call "
-            "'pip install google-cloud-aiplatform[reasoningengine]'."
+            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
         return None
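
Note for reviewers (illustration only, not part of the patch): a minimal sketch of how a callable is expected to satisfy the new GetOperationFunction protocol added in vertexai/_genai/_agent_engines_utils.py, so it can be passed as get_operation_fn to _await_operation. The name my_get_operation is hypothetical.

    from vertexai._genai import _agent_engines_utils


    def my_get_operation(
        *, operation_name: str, **kwargs
    ) -> _agent_engines_utils.AgentEngineOperationUnion:
        # A real implementation would look up the long-running operation by
        # name and return one of the operation types in
        # AgentEngineOperationUnion; elided here in this illustrative stub.
        raise NotImplementedError(operation_name)


    # Any callable with this keyword-only signature structurally satisfies
    # the Protocol, so a type checker accepts the assignment below.
    get_operation_fn: _agent_engines_utils.GetOperationFunction = my_get_operation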