diff --git a/tests/unit/vertexai/test_reasoning_engines.py b/tests/unit/vertexai/test_reasoning_engines.py index 7158838771..21a39aa03f 100644 --- a/tests/unit/vertexai/test_reasoning_engines.py +++ b/tests/unit/vertexai/test_reasoning_engines.py @@ -294,6 +294,36 @@ def test_create_reasoning_engine( retry=_TEST_RETRY, ) + def test_create_reasoning_engine_requirements_from_file( + self, + create_reasoning_engine_mock, + cloud_storage_create_bucket_mock, + tarfile_open_mock, + cloudpickle_dump_mock, + get_reasoning_engine_mock, + ): + with mock.patch( + "builtins.open", + mock.mock_open(read_data="google-cloud-aiplatform==1.29.0"), + ) as mock_file: + test_reasoning_engine = reasoning_engines.ReasoningEngine.create( + self.test_app, + reasoning_engine_name=_TEST_REASONING_ENGINE_RESOURCE_NAME, + display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME, + requirements="requirements.txt", + ) + mock_file.assert_called_with("requirements.txt") + # Manually set _gca_resource here to prevent the mocks from propagating. + test_reasoning_engine._gca_resource = _TEST_REASONING_ENGINE_OBJ + create_reasoning_engine_mock.assert_called_with( + parent=_TEST_PARENT, + reasoning_engine=test_reasoning_engine.gca_resource, + ) + get_reasoning_engine_mock.assert_called_with( + name=_TEST_REASONING_ENGINE_RESOURCE_NAME, + retry=_TEST_RETRY, + ) + def test_delete_after_create_reasoning_engine( self, create_reasoning_engine_mock, @@ -407,6 +437,22 @@ def test_create_reasoning_engine_unsupported_sys_version( sys_version="2.6", ) + def test_create_reasoning_engine_requirements_ioerror( + self, + create_reasoning_engine_mock, + cloud_storage_create_bucket_mock, + tarfile_open_mock, + cloudpickle_dump_mock, + get_reasoning_engine_mock, + ): + with pytest.raises(IOError, match="Failed to read requirements"): + reasoning_engines.ReasoningEngine.create( + self.test_app, + reasoning_engine_name=_TEST_REASONING_ENGINE_RESOURCE_NAME, + display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME, + requirements="nonexistent_requirements.txt", + ) + def test_create_reasoning_engine_nonexistent_extra_packages( self, create_reasoning_engine_mock, diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 16e1e99174..77b38f7eae 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -20,7 +20,7 @@ import sys import tarfile import typing -from typing import Optional, Protocol, Sequence +from typing import Optional, Protocol, Sequence, Union from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer @@ -84,7 +84,7 @@ def create( cls, reasoning_engine: Queryable, *, - requirements: Optional[Sequence[str]] = None, + requirements: Optional[Union[str, Sequence[str]]] = None, reasoning_engine_name: Optional[str] = None, display_name: Optional[str] = None, description: Optional[str] = None, @@ -131,8 +131,10 @@ def create( Args: reasoning_engine (ReasoningEngineInterface): Required. The Reasoning Engine to be created. - requirements (Sequence[str]): - Optional. The set of PyPI dependencies needed. + requirements (Union[str, Sequence[str]]): + Optional. The set of PyPI dependencies needed. It can either be + the path to a single file (requirements.txt), or an ordered list + of strings corresponding to each line of the requirements file. reasoning_engine_name (str): Optional. A fully-qualified resource name or ID such as "projects/123/locations/us-central1/reasoningEngines/456" or @@ -202,6 +204,16 @@ def create( "Invalid query signature. This might be due to a missing " "`self` argument in the reasoning_engine.query method." ) from err + if isinstance(requirements, str): + try: + _LOGGER.info(f"Reading requirements from {requirements=}") + with open(requirements) as f: + requirements = f.read().splitlines() + _LOGGER.info(f"Read the following lines: {requirements}") + except IOError as err: + raise IOError( + f"Failed to read requirements from {requirements=}" + ) from err requirements = requirements or [] extra_packages = extra_packages or [] for extra_package in extra_packages: