Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion vertexai/_genai/_agent_engines_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,22 @@
#
"""Utility functions for agent engines."""

from typing import Any, Callable, Coroutine, Iterator, AsyncIterator
from typing import Any, Callable, Coroutine, Iterator, AsyncIterator, Protocol, Union
from . import types


AgentEngineOperationUnion = Union[
types.AgentEngineOperation,
types.AgentEngineMemoryOperation,
types.AgentEngineGenerateMemoriesOperation,
]


class GetOperationFunction(Protocol):
def __call__(self, *, operation_name: str, **kwargs) -> AgentEngineOperationUnion:
...


def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
"""Wraps an Agent Engine method, creating a callable for `query` API.

Expand Down
55 changes: 34 additions & 21 deletions vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
import json
import logging
import time
from typing import Any, Callable, Iterator, Optional, Sequence, Union
from typing import Any, Iterator, Optional, Sequence, Tuple, Union
from urllib.parse import urlencode

from google.genai import _api_module
from google.genai import _common
from google.genai import types as genai_types
from google.genai._common import get_value_by_path as getv
from google.genai._common import set_value_by_path as setv
from google.genai.pagers import Pager
Expand Down Expand Up @@ -2045,10 +2044,13 @@ def create(
api_async_client=AsyncAgentEngines(api_client_=self._api_client),
api_resource=operation.response,
)
logger.info("Agent Engine created. To use it in another session:")
logger.info(
f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
)
if agent.api_resource:
logger.info("Agent Engine created. To use it in another session:")
logger.info(
f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
)
else:
logger.warning("The operation returned an empty response.")
if agent_engine is not None:
# If the user did not provide an agent_engine (e.g. lightweight
# provisioning), it will not have any API methods registered.
Expand All @@ -2067,13 +2069,13 @@ def _create_config(
gcs_dir_name: Optional[str] = None,
extra_packages: Optional[Sequence[str]] = None,
env_vars: Optional[dict[str, Union[str, Any]]] = None,
context_spec: Optional[dict[str, Any]] = None,
):
context_spec: Optional[types.ReasoningEngineContextSpecDict] = None,
) -> types.UpdateAgentEngineConfigDict:
import sys
from vertexai.agent_engines import _agent_engines
from vertexai.agent_engines import _utils

config = {}
config: types.UpdateAgentEngineConfigDict = {}
update_masks = []
if mode not in ["create", "update"]:
raise ValueError(f"Unsupported mode: {mode}")
Expand All @@ -2091,28 +2093,36 @@ def _create_config(
if context_spec is not None:
config["context_spec"] = context_spec
if agent_engine is not None:
project = self._api_client.project
if project is None:
raise ValueError("project must be set using `vertexai.Client`.")
location = self._api_client.location
if location is None:
raise ValueError("location must be set using `vertexai.Client`.")
sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"
gcs_dir_name = gcs_dir_name or _agent_engines._DEFAULT_GCS_DIR_NAME
agent_engine = _agent_engines._validate_agent_engine_or_raise(
agent_engine=agent_engine, logger=logger
)
_agent_engines._validate_staging_bucket_or_raise(staging_bucket)
staging_bucket = _agent_engines._validate_staging_bucket_or_raise(
staging_bucket=staging_bucket
)
requirements = _agent_engines._validate_requirements_or_raise(
agent_engine=agent_engine,
requirements=requirements,
logger=logger,
)
extra_packages = _agent_engines._validate_extra_packages_or_raise(
extra_packages
extra_packages=extra_packages,
)
# Prepares the Agent Engine for creation/update in Vertex AI. This
# involves packaging and uploading the artifacts for agent_engine,
# requirements and extra_packages to `staging_bucket/gcs_dir_name`.
_agent_engines._prepare(
agent_engine=agent_engine,
requirements=requirements,
project=self._api_client.project,
location=self._api_client.location,
project=project,
location=location,
staging_bucket=staging_bucket,
gcs_dir_name=gcs_dir_name,
extra_packages=extra_packages,
Expand Down Expand Up @@ -2142,7 +2152,9 @@ def _create_config(
gcs_dir_name,
_agent_engines._REQUIREMENTS_FILE,
)
agent_engine_spec = {"package_spec": package_spec}
agent_engine_spec: types.ReasoningEngineSpecDict = {
"package_spec": package_spec,
}
if env_vars is not None:
(
deployment_spec,
Expand Down Expand Up @@ -2172,7 +2184,7 @@ def _generate_deployment_spec_or_raise(
self,
*,
env_vars: Optional[dict[str, Union[str, Any]]] = None,
):
) -> Tuple[dict[str, Any], Sequence[str]]:
deployment_spec: dict[str, Any] = {}
update_masks = []
if env_vars:
Expand Down Expand Up @@ -2217,7 +2229,7 @@ def _await_operation(
*,
operation_name: str,
poll_interval_seconds: int = 10,
get_operation_fn: Optional[Callable[[str], Any]] = None,
get_operation_fn: Optional[_agent_engines_utils.GetOperationFunction] = None,
) -> Any:
"""Waits for the operation for creating an agent engine to complete.

Expand Down Expand Up @@ -2375,10 +2387,11 @@ def update(
api_async_client=AsyncAgentEngines(api_client_=self._api_client),
api_resource=operation.response,
)
logger.info("Agent Engine updated. To use it in another session:")
logger.info(
f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
)
if agent.api_resource:
logger.info("Agent Engine updated. To use it in another session:")
logger.info(
f"agent_engine=client.agent_engines.get('{agent.api_resource.name}')"
)
return self._register_api_methods(agent=agent)

def _stream_query(
Expand All @@ -2403,7 +2416,7 @@ def _stream_query(
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)
http_options: Optional[genai_types.HttpOptions] = None
http_options = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
Expand Down
27 changes: 15 additions & 12 deletions vertexai/agent_engines/_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import abc
import inspect
import io
import logging
import os
import sys
import tarfile
Expand Down Expand Up @@ -66,7 +67,7 @@
_DEFAULT_STREAM_METHOD_NAME = "stream_query"
_DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query"
_DEFAULT_METHOD_RETURN_TYPE = "dict[str, Any]"
_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]"
_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any, Any, Any]"
_DEFAULT_STREAM_METHOD_RETURN_TYPE = "Iterable[Any]"
_DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE = "AsyncIterable[Any]"
_DEFAULT_METHOD_DOCSTRING_TEMPLATE = """
Expand Down Expand Up @@ -116,7 +117,7 @@ class Queryable(Protocol):
"""Protocol for Agent Engines that can be queried."""

@abc.abstractmethod
def query(self, **kwargs):
def query(self, **kwargs) -> Any:
"""Runs the Agent Engine to serve the user query."""


Expand All @@ -125,7 +126,7 @@ class AsyncQueryable(Protocol):
"""Protocol for Agent Engines that can be queried asynchronously."""

@abc.abstractmethod
def async_query(self, **kwargs):
def async_query(self, **kwargs) -> Coroutine[Any, Any, Any]:
"""Runs the Agent Engine to serve the user query asynchronously."""


Expand All @@ -143,7 +144,7 @@ class StreamQueryable(Protocol):
"""Protocol for Agent Engines that can stream responses."""

@abc.abstractmethod
def stream_query(self, **kwargs):
def stream_query(self, **kwargs) -> Iterable[Any]:
"""Stream responses to serve the user query."""


Expand All @@ -152,7 +153,7 @@ class Cloneable(Protocol):
"""Protocol for Agent Engines that can be cloned."""

@abc.abstractmethod
def clone(self):
def clone(self) -> Any:
"""Return a clone of the object."""


Expand All @@ -161,7 +162,7 @@ class OperationRegistrable(Protocol):
"""Protocol for agents that have registered operations."""

@abc.abstractmethod
def register_operations(self, **kwargs):
def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
"""Register the user provided operations (modes and methods)."""


Expand Down Expand Up @@ -238,7 +239,7 @@ def clone(self):
sys_paths=self._tmpl_attrs.get("sys_paths"),
)

def register_operations(self) -> Dict[str, Sequence[str]]:
def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
return self._tmpl_attrs.get("register_operations")

def set_up(self) -> None:
Expand Down Expand Up @@ -441,7 +442,7 @@ def create(
staging_bucket = initializer.global_config.staging_bucket
if agent_engine is not None:
agent_engine = _validate_agent_engine_or_raise(agent_engine)
_validate_staging_bucket_or_raise(staging_bucket)
staging_bucket = _validate_staging_bucket_or_raise(staging_bucket)
if agent_engine is None:
if requirements is not None:
raise ValueError("requirements must be None if agent_engine is None.")
Expand Down Expand Up @@ -634,7 +635,7 @@ def update(
nonexistent file.
"""
staging_bucket = initializer.global_config.staging_bucket
_validate_staging_bucket_or_raise(staging_bucket)
staging_bucket = _validate_staging_bucket_or_raise(staging_bucket)
historical_operation_schemas = self.operation_schemas()
gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME

Expand Down Expand Up @@ -780,12 +781,13 @@ def _validate_sys_version_or_raise(sys_version: str) -> None:
)


def _validate_staging_bucket_or_raise(staging_bucket: str) -> str:
def _validate_staging_bucket_or_raise(staging_bucket: Optional[str]) -> str:
"""Tries to validate the staging bucket."""
if not staging_bucket:
raise ValueError("Please provide a `staging_bucket` in `vertexai.init(...)`")
if not staging_bucket.startswith("gs://"):
raise ValueError(f"{staging_bucket=} must start with `gs://`")
return staging_bucket


def _validate_agent_engine_or_raise(
Expand Down Expand Up @@ -906,7 +908,7 @@ def _validate_requirements_or_raise(
*,
agent_engine: _AgentEngineInterface,
requirements: Optional[Sequence[str]] = None,
logger: base.Logger = _LOGGER,
logger: logging.getLoggerClass() = _LOGGER,
) -> Sequence[str]:
"""Tries to validate the requirements."""
if requirements is None:
Expand All @@ -929,7 +931,7 @@ def _validate_requirements_or_raise(


def _validate_extra_packages_or_raise(
extra_packages: Sequence[str],
extra_packages: Optional[Sequence[str]],
build_options: Optional[Dict[str, Sequence[str]]] = None,
) -> Sequence[str]:
"""Tries to validates the extra packages."""
Expand Down Expand Up @@ -1165,6 +1167,7 @@ def _get_agent_framework(
if (
hasattr(agent_engine, _AGENT_FRAMEWORK_ATTR)
and getattr(agent_engine, _AGENT_FRAMEWORK_ATTR) is not None
and isinstance(getattr(agent_engine, _AGENT_FRAMEWORK_ATTR), str)
):
return getattr(agent_engine, _AGENT_FRAMEWORK_ATTR)
return _DEFAULT_AGENT_FRAMEWORK
Expand Down
24 changes: 10 additions & 14 deletions vertexai/reasoning_engines/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def to_dict(message: proto.Message) -> JsonDict:
return result


def dataclass_to_dict(obj: dataclasses.dataclass) -> JsonDict:
def dataclass_to_dict(obj: dataclasses.dataclass) -> Any:
"""Converts a dataclass to a JSON dictionary.

Args:
Expand All @@ -116,7 +116,7 @@ def dataclass_to_dict(obj: dataclasses.dataclass) -> JsonDict:
return json.loads(json.dumps(dataclasses.asdict(obj)))


def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]:
def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Any:
response = {}
if hasattr(obj, "response"):
response["response"] = obj.response
Expand All @@ -128,15 +128,11 @@ def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]:
return json.loads(json.dumps(response))


def _llama_index_chat_response_to_dict(
obj: LlamaIndexChatResponse,
) -> Dict[str, Any]:
def _llama_index_chat_response_to_dict(obj: LlamaIndexChatResponse) -> Any:
return json.loads(obj.message.model_dump_json())


def _llama_index_base_model_to_dict(
obj: LlamaIndexBaseModel,
) -> Dict[str, Any]:
def _llama_index_base_model_to_dict(obj: LlamaIndexBaseModel) -> Any:
return json.loads(obj.model_dump_json())


Expand Down Expand Up @@ -330,7 +326,7 @@ def _import_cloud_storage_or_raise() -> types.ModuleType:
except ImportError as e:
raise ImportError(
"Cloud Storage is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return storage

Expand All @@ -342,7 +338,7 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
except ImportError as e:
raise ImportError(
"cloudpickle is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return cloudpickle

Expand All @@ -358,7 +354,7 @@ def _import_pydantic_or_raise() -> types.ModuleType:
except ImportError as e:
raise ImportError(
"pydantic is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return pydantic

Expand All @@ -372,7 +368,7 @@ def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
except ImportError:
_LOGGER.warning(
"opentelemetry-sdk is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
"'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

Expand All @@ -386,7 +382,7 @@ def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
except ImportError:
_LOGGER.warning(
"opentelemetry-sdk is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
"'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

Expand All @@ -400,7 +396,7 @@ def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
except ImportError:
_LOGGER.warning(
"google-cloud-trace is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
"'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

Expand Down
Loading