Skip to content

Commit

Permalink
feat: Add support for reading requirements from a file.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625068034
  • Loading branch information
Yeesian Ng authored and Copybara-Service committed Apr 15, 2024
1 parent 67de093 commit 80db7a0
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 4 deletions.
46 changes: 46 additions & 0 deletions tests/unit/vertexai/test_reasoning_engines.py
Expand Up @@ -294,6 +294,36 @@ def test_create_reasoning_engine(
retry=_TEST_RETRY,
)

def test_create_reasoning_engine_requirements_from_file(
self,
create_reasoning_engine_mock,
cloud_storage_create_bucket_mock,
tarfile_open_mock,
cloudpickle_dump_mock,
get_reasoning_engine_mock,
):
with mock.patch(
"builtins.open",
mock.mock_open(read_data="google-cloud-aiplatform==1.29.0"),
) as mock_file:
test_reasoning_engine = reasoning_engines.ReasoningEngine.create(
self.test_app,
reasoning_engine_name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
requirements="requirements.txt",
)
mock_file.assert_called_with("requirements.txt")
# Manually set _gca_resource here to prevent the mocks from propagating.
test_reasoning_engine._gca_resource = _TEST_REASONING_ENGINE_OBJ
create_reasoning_engine_mock.assert_called_with(
parent=_TEST_PARENT,
reasoning_engine=test_reasoning_engine.gca_resource,
)
get_reasoning_engine_mock.assert_called_with(
name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
retry=_TEST_RETRY,
)

def test_delete_after_create_reasoning_engine(
self,
create_reasoning_engine_mock,
Expand Down Expand Up @@ -407,6 +437,22 @@ def test_create_reasoning_engine_unsupported_sys_version(
sys_version="2.6",
)

def test_create_reasoning_engine_requirements_ioerror(
self,
create_reasoning_engine_mock,
cloud_storage_create_bucket_mock,
tarfile_open_mock,
cloudpickle_dump_mock,
get_reasoning_engine_mock,
):
with pytest.raises(IOError, match="Failed to read requirements"):
reasoning_engines.ReasoningEngine.create(
self.test_app,
reasoning_engine_name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
requirements="nonexistent_requirements.txt",
)

def test_create_reasoning_engine_nonexistent_extra_packages(
self,
create_reasoning_engine_mock,
Expand Down
20 changes: 16 additions & 4 deletions vertexai/reasoning_engines/_reasoning_engines.py
Expand Up @@ -20,7 +20,7 @@
import sys
import tarfile
import typing
from typing import Optional, Protocol, Sequence
from typing import Optional, Protocol, Sequence, Union

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -84,7 +84,7 @@ def create(
cls,
reasoning_engine: Queryable,
*,
requirements: Optional[Sequence[str]] = None,
requirements: Optional[Union[str, Sequence[str]]] = None,
reasoning_engine_name: Optional[str] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -131,8 +131,10 @@ def create(
Args:
reasoning_engine (ReasoningEngineInterface):
Required. The Reasoning Engine to be created.
requirements (Sequence[str]):
Optional. The set of PyPI dependencies needed.
requirements (Union[str, Sequence[str]]):
Optional. The set of PyPI dependencies needed. It can either be
the path to a single file (requirements.txt), or an ordered list
of strings corresponding to each line of the requirements file.
reasoning_engine_name (str):
Optional. A fully-qualified resource name or ID such as
"projects/123/locations/us-central1/reasoningEngines/456" or
Expand Down Expand Up @@ -202,6 +204,16 @@ def create(
"Invalid query signature. This might be due to a missing "
"`self` argument in the reasoning_engine.query method."
) from err
if isinstance(requirements, str):
try:
_LOGGER.info(f"Reading requirements from {requirements=}")
with open(requirements) as f:
requirements = f.read().splitlines()
_LOGGER.info(f"Read the following lines: {requirements}")
except IOError as err:
raise IOError(
f"Failed to read requirements from {requirements=}"
) from err
requirements = requirements or []
extra_packages = extra_packages or []
for extra_package in extra_packages:
Expand Down

0 comments on commit 80db7a0

Please sign in to comment.