From 26657ffd25ecb91882ca764e513c2e952833257f Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Wed, 1 May 2024 15:03:20 -0700 Subject: [PATCH] feat: Automatically populate parents for full resource name in Vertex RAG SDK PiperOrigin-RevId: 629849569 --- tests/unit/vertex_rag/test_rag_data.py | 29 +++++++++++-- vertexai/preview/rag/rag_data.py | 49 ++++++++++++++++------ vertexai/preview/rag/utils/_gapic_utils.py | 49 ++++++++++++++++++++++ 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index cf1613eab7..6b2f3e98f7 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -17,7 +17,9 @@ import importlib from google.api_core import operation as ga_operation from vertexai.preview import rag -from vertexai.preview.rag.utils._gapic_utils import prepare_import_files_request +from vertexai.preview.rag.utils._gapic_utils import ( + prepare_import_files_request, +) from google.cloud.aiplatform_v1beta1 import ( VertexRagDataServiceAsyncClient, VertexRagDataServiceClient, @@ -184,6 +186,11 @@ def test_get_corpus_success(self): rag_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME) rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("rag_data_client_mock") + def test_get_corpus_id_success(self): + rag_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_ID) + rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_get_corpus_failure(self): with pytest.raises(RuntimeError) as e: @@ -208,7 +215,11 @@ def test_list_corpora_failure(self): def test_delete_corpus_success(self, rag_data_client_mock): rag.delete_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME) - rag_data_client_mock.assert_called_once() + assert rag_data_client_mock.call_count == 2 + + def test_delete_corpus_id_success(self, rag_data_client_mock): + rag.delete_corpus(tc.TEST_RAG_CORPUS_ID) + assert rag_data_client_mock.call_count == 2 @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_delete_corpus_failure(self): @@ -311,6 +322,13 @@ def test_get_file_success(self): rag_file = rag.get_file(tc.TEST_RAG_FILE_RESOURCE_NAME) rag_file_eq(rag_file, tc.TEST_RAG_FILE) + @pytest.mark.usefixtures("rag_data_client_mock") + def test_get_file_id_success(self): + rag_file = rag.get_file( + name=tc.TEST_RAG_FILE_ID, corpus_name=tc.TEST_RAG_CORPUS_ID + ) + rag_file_eq(rag_file, tc.TEST_RAG_FILE) + @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_get_file_failure(self): with pytest.raises(RuntimeError) as e: @@ -333,7 +351,12 @@ def test_list_files_failure(self): def test_delete_file_success(self, rag_data_client_mock): rag.delete_file(tc.TEST_RAG_FILE_RESOURCE_NAME) - rag_data_client_mock.assert_called_once() + assert rag_data_client_mock.call_count == 2 + + def test_delete_file_id_success(self, rag_data_client_mock): + rag.delete_file(name=tc.TEST_RAG_FILE_ID, corpus_name=tc.TEST_RAG_CORPUS_ID) + # Passing corpus_name will result in 3 calls to rag_data_client + assert rag_data_client_mock.call_count == 3 @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_delete_file_failure(self): diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 31ecc31488..149ada8b1b 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -37,7 +37,6 @@ from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import ( ListRagCorporaPager, ListRagFilesPager, - ) from vertexai.preview.rag.utils import ( _gapic_utils, @@ -100,10 +99,12 @@ def get_corpus(name: str) -> RagCorpus: Args: name: An existing RagCorpus resource name. Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. Returns: RagCorpus. """ - request = GetRagCorpusRequest(name=name) + corpus_name = _gapic_utils.get_corpus_name(name) + request = GetRagCorpusRequest(name=corpus_name) client = _gapic_utils.create_rag_data_service_client() try: response = client.get_rag_corpus(request=request) @@ -163,8 +164,10 @@ def delete_corpus(name: str) -> None: Args: name: An existing RagCorpus resource name. Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. """ - request = DeleteRagCorpusRequest(name=name) + corpus_name = _gapic_utils.get_corpus_name(name) + request = DeleteRagCorpusRequest(name=corpus_name) client = _gapic_utils.create_rag_data_service_client() try: @@ -200,7 +203,8 @@ def upload_file( Args: corpus_name: The name of the RagCorpus resource into which to upload the file. - Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. path: A local file path. For example, "usr/home/my_file.txt". display_name: The display name of the data file. @@ -212,6 +216,7 @@ def upload_file( ValueError: RagCorpus is not found. RuntimeError: Failed in indexing the RagFile. """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) location = initializer.global_config.location # GAPIC doesn't expose a path (scotty). Use requests API instead if display_name is None: @@ -286,6 +291,7 @@ def import_files( Args: corpus_name: The name of the RagCorpus resource into which to import files. Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. paths: A list of uris. Elligible uris will be Google Cloud Storage directory ("gs://my-bucket/my_dir") or a Google Drive url for file (https://drive.google.com/file/... or folder @@ -296,7 +302,7 @@ def import_files( Returns: ImportRagFilesResponse. """ - + corpus_name = _gapic_utils.get_corpus_name(corpus_name) request = _gapic_utils.prepare_import_files_request( corpus_name=corpus_name, paths=paths, @@ -347,6 +353,7 @@ async def import_files_async( Args: corpus_name: The name of the RagCorpus resource into which to import files. Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. paths: A list of uris. Elligible uris will be Google Cloud Storage directory ("gs://my-bucket/my_dir") or a Google Drive url for file (https://drive.google.com/file/... or folder @@ -356,7 +363,7 @@ async def import_files_async( Returns: operation_async.AsyncOperation. """ - + corpus_name = _gapic_utils.get_corpus_name(corpus_name) request = _gapic_utils.prepare_import_files_request( corpus_name=corpus_name, paths=paths, @@ -371,16 +378,24 @@ async def import_files_async( return response -def get_file(name: str) -> RagFile: +def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile: """ Get an existing RagFile. Args: - name: A RagFile resource name. Format: - ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + name: Either a full RagFile resource name must be provided, or a RagCorpus + name and a RagFile name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + or ``{rag_file}``. + corpus_name: If `name` is not a full resource name, an existing RagCorpus + name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. Returns: RagFile. """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + name = _gapic_utils.get_file_name(name, corpus_name) request = GetRagFileRequest(name=name) client = _gapic_utils.create_rag_data_service_client() try: @@ -423,13 +438,15 @@ def list_files( Args: corpus_name: An existing RagCorpus name. Format: - ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. page_size: The standard list page size. Leaving out the page_size causes all of the results to be returned. page_token: The standard list page token. Returns: ListRagFilesPager. """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) request = ListRagFilesRequest( parent=corpus_name, page_size=page_size, @@ -444,14 +461,22 @@ def list_files( return pager -def delete_file(name: str) -> None: +def delete_file(name: str, corpus_name: Optional[str] = None) -> None: """ Delete RagFile from an existing RagCorpus. Args: - name: A RagFile resource name. Format: + name: Either a full RagFile resource name must be provided, or a RagCorpus + name and a RagFile name must be provided. Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + or ``{rag_file}``. + corpus_name: If `name` is not a full resource name, an existing RagCorpus + name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + name = _gapic_utils.get_file_name(name, corpus_name) request = DeleteRagFileRequest(name=name) client = _gapic_utils.create_rag_data_service_client() diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index ca429db980..6a8510ff23 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import re from typing import Any, Dict, Sequence, Union from google.cloud.aiplatform_v1beta1 import ( GoogleDriveSource, @@ -35,6 +36,9 @@ ) +_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}" + + def create_rag_data_service_client(): return initializer.global_config.create_client( client_class=VertexRagDataClientWithOverride, @@ -153,3 +157,48 @@ def prepare_import_files_request( parent=corpus_name, import_rag_files_config=import_rag_files_config ) return request + + +def get_corpus_name( + name: str, +) -> str: + if name: + client = create_rag_data_service_client() + if client.parse_rag_corpus_path(name): + return name + elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name): + return client.rag_corpus_path( + project=initializer.global_config.project, + location=initializer.global_config.location, + rag_corpus=name, + ) + else: + raise ValueError( + "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`" + ) + return name + + +def get_file_name( + name: str, + corpus_name: str, +) -> str: + client = create_rag_data_service_client() + if client.parse_rag_file_path(name): + return name + elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name): + if not corpus_name: + raise ValueError( + "corpus_name must be provided if name is a `{rag_file}`, not a " + "full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). " + ) + return client.rag_file_path( + project=initializer.global_config.project, + location=initializer.global_config.location, + rag_corpus=get_corpus_name(corpus_name), + rag_file=name, + ) + else: + raise ValueError( + "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`" + )