Skip to content

Commit

Permalink
feat: Automatically populate parents for full resource name in Vertex…
Browse files Browse the repository at this point in the history
… RAG SDK

PiperOrigin-RevId: 629849569
  • Loading branch information
yinghsienwu authored and Copybara-Service committed May 1, 2024
1 parent 2d19137 commit 26657ff
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 15 deletions.
29 changes: 26 additions & 3 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import importlib
from google.api_core import operation as ga_operation
from vertexai.preview import rag
from vertexai.preview.rag.utils._gapic_utils import prepare_import_files_request
from vertexai.preview.rag.utils._gapic_utils import (
prepare_import_files_request,
)
from google.cloud.aiplatform_v1beta1 import (
VertexRagDataServiceAsyncClient,
VertexRagDataServiceClient,
Expand Down Expand Up @@ -184,6 +186,11 @@ def test_get_corpus_success(self):
rag_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME)
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)

@pytest.mark.usefixtures("rag_data_client_mock")
def test_get_corpus_id_success(self):
    """``rag.get_corpus`` called with ``tc.TEST_RAG_CORPUS_ID`` yields ``tc.TEST_RAG_CORPUS``."""
    fetched_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_ID)
    rag_corpus_eq(fetched_corpus, tc.TEST_RAG_CORPUS)

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_get_corpus_failure(self):
with pytest.raises(RuntimeError) as e:
Expand All @@ -208,7 +215,11 @@ def test_list_corpora_failure(self):

def test_delete_corpus_success(self, rag_data_client_mock):
rag.delete_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME)
rag_data_client_mock.assert_called_once()
assert rag_data_client_mock.call_count == 2

def test_delete_corpus_id_success(self, rag_data_client_mock):
    """``rag.delete_corpus`` called with ``tc.TEST_RAG_CORPUS_ID`` hits the client mock twice."""
    rag.delete_corpus(tc.TEST_RAG_CORPUS_ID)
    observed_calls = rag_data_client_mock.call_count
    assert observed_calls == 2

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_delete_corpus_failure(self):
Expand Down Expand Up @@ -311,6 +322,13 @@ def test_get_file_success(self):
rag_file = rag.get_file(tc.TEST_RAG_FILE_RESOURCE_NAME)
rag_file_eq(rag_file, tc.TEST_RAG_FILE)

@pytest.mark.usefixtures("rag_data_client_mock")
def test_get_file_id_success(self):
    """``rag.get_file`` given ``tc.TEST_RAG_FILE_ID`` and ``tc.TEST_RAG_CORPUS_ID`` yields ``tc.TEST_RAG_FILE``."""
    lookup = {"name": tc.TEST_RAG_FILE_ID, "corpus_name": tc.TEST_RAG_CORPUS_ID}
    fetched_file = rag.get_file(**lookup)
    rag_file_eq(fetched_file, tc.TEST_RAG_FILE)

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_get_file_failure(self):
with pytest.raises(RuntimeError) as e:
Expand All @@ -333,7 +351,12 @@ def test_list_files_failure(self):

def test_delete_file_success(self, rag_data_client_mock):
rag.delete_file(tc.TEST_RAG_FILE_RESOURCE_NAME)
rag_data_client_mock.assert_called_once()
assert rag_data_client_mock.call_count == 2

def test_delete_file_id_success(self, rag_data_client_mock):
    """``rag.delete_file`` given a file ID plus a corpus ID hits the client mock three times."""
    rag.delete_file(corpus_name=tc.TEST_RAG_CORPUS_ID, name=tc.TEST_RAG_FILE_ID)
    # Passing corpus_name will result in 3 calls to rag_data_client
    observed_calls = rag_data_client_mock.call_count
    assert observed_calls == 3

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_delete_file_failure(self):
Expand Down
49 changes: 37 additions & 12 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import (
ListRagCorporaPager,
ListRagFilesPager,

)
from vertexai.preview.rag.utils import (
_gapic_utils,
Expand Down Expand Up @@ -100,10 +99,12 @@ def get_corpus(name: str) -> RagCorpus:
Args:
name: An existing RagCorpus resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
Returns:
RagCorpus.
"""
request = GetRagCorpusRequest(name=name)
corpus_name = _gapic_utils.get_corpus_name(name)
request = GetRagCorpusRequest(name=corpus_name)
client = _gapic_utils.create_rag_data_service_client()
try:
response = client.get_rag_corpus(request=request)
Expand Down Expand Up @@ -163,8 +164,10 @@ def delete_corpus(name: str) -> None:
Args:
name: An existing RagCorpus resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
"""
request = DeleteRagCorpusRequest(name=name)
corpus_name = _gapic_utils.get_corpus_name(name)
request = DeleteRagCorpusRequest(name=corpus_name)

client = _gapic_utils.create_rag_data_service_client()
try:
Expand Down Expand Up @@ -200,7 +203,8 @@ def upload_file(
Args:
corpus_name: The name of the RagCorpus resource into which to upload the file.
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
path: A local file path. For example,
"usr/home/my_file.txt".
display_name: The display name of the data file.
Expand All @@ -212,6 +216,7 @@ def upload_file(
ValueError: RagCorpus is not found.
RuntimeError: Failed in indexing the RagFile.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
location = initializer.global_config.location
# GAPIC doesn't expose a path (scotty). Use requests API instead
if display_name is None:
Expand Down Expand Up @@ -286,6 +291,7 @@ def import_files(
Args:
corpus_name: The name of the RagCorpus resource into which to import files.
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
paths: A list of uris. Eligible uris will be Google Cloud Storage
directory ("gs://my-bucket/my_dir") or a Google Drive url for file
(https://drive.google.com/file/... or folder
Expand All @@ -296,7 +302,7 @@ def import_files(
Returns:
ImportRagFilesResponse.
"""

corpus_name = _gapic_utils.get_corpus_name(corpus_name)
request = _gapic_utils.prepare_import_files_request(
corpus_name=corpus_name,
paths=paths,
Expand Down Expand Up @@ -347,6 +353,7 @@ async def import_files_async(
Args:
corpus_name: The name of the RagCorpus resource into which to import files.
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
paths: A list of uris. Eligible uris will be Google Cloud Storage
directory ("gs://my-bucket/my_dir") or a Google Drive url for file
(https://drive.google.com/file/... or folder
Expand All @@ -356,7 +363,7 @@ async def import_files_async(
Returns:
operation_async.AsyncOperation.
"""

corpus_name = _gapic_utils.get_corpus_name(corpus_name)
request = _gapic_utils.prepare_import_files_request(
corpus_name=corpus_name,
paths=paths,
Expand All @@ -371,16 +378,24 @@ async def import_files_async(
return response


def get_file(name: str) -> RagFile:
def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile:
"""
Get an existing RagFile.
Args:
name: A RagFile resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
name: Either a full RagFile resource name must be provided, or a RagCorpus
name and a RagFile name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
or ``{rag_file}``.
corpus_name: If `name` is not a full resource name, an existing RagCorpus
name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
Returns:
RagFile.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
name = _gapic_utils.get_file_name(name, corpus_name)
request = GetRagFileRequest(name=name)
client = _gapic_utils.create_rag_data_service_client()
try:
Expand Down Expand Up @@ -423,13 +438,15 @@ def list_files(
Args:
corpus_name: An existing RagCorpus name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
page_size: The standard list page size. Leaving out the page_size
causes all of the results to be returned.
page_token: The standard list page token.
Returns:
ListRagFilesPager.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
request = ListRagFilesRequest(
parent=corpus_name,
page_size=page_size,
Expand All @@ -444,14 +461,22 @@ def list_files(
return pager


def delete_file(name: str) -> None:
def delete_file(name: str, corpus_name: Optional[str] = None) -> None:
"""
Delete RagFile from an existing RagCorpus.
Args:
name: A RagFile resource name. Format:
name: Either a full RagFile resource name must be provided, or a RagCorpus
name and a RagFile name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
or ``{rag_file}``.
corpus_name: If `name` is not a full resource name, an existing RagCorpus
name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
name = _gapic_utils.get_file_name(name, corpus_name)
request = DeleteRagFileRequest(name=name)

client = _gapic_utils.create_rag_data_service_client()
Expand Down
49 changes: 49 additions & 0 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Any, Dict, Sequence, Union
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
Expand All @@ -35,6 +36,9 @@
)


_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"


def create_rag_data_service_client():
return initializer.global_config.create_client(
client_class=VertexRagDataClientWithOverride,
Expand Down Expand Up @@ -153,3 +157,48 @@ def prepare_import_files_request(
parent=corpus_name, import_rag_files_config=import_rag_files_config
)
return request


def get_corpus_name(
    name: str,
) -> str:
    """Expand a RagCorpus identifier into a full RagCorpus resource name.

    Args:
        name: Either a full resource name
            (``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``)
            or a bare ``{rag_corpus}`` ID. A falsy value (``None`` or ``""``)
            is returned unchanged, so callers can forward an optional corpus
            name without checking it first.

    Returns:
        ``name`` itself when it already parses as a RagCorpus path; otherwise
        the path built from the bare ID plus the project and location held in
        ``initializer.global_config``.

    Raises:
        ValueError: ``name`` is non-empty but is neither a full RagCorpus
            resource name nor a valid bare ID.
    """
    if not name:
        return name

    client = create_rag_data_service_client()
    if client.parse_rag_corpus_path(name):
        return name

    # fullmatch rather than match("^...$"): `$` also matches just before a
    # trailing newline, which would let "my-corpus\n" through and embed the
    # newline in the generated resource name.
    if re.fullmatch(_VALID_RESOURCE_NAME_REGEX, name):
        return client.rag_corpus_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=name,
        )

    raise ValueError(
        "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
    )


def get_file_name(
    name: str,
    corpus_name: str,
) -> str:
    """Expand a RagFile identifier into a full RagFile resource name.

    Args:
        name: Either a full resource name
            (``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``)
            or a bare ``{rag_file}`` ID.
        corpus_name: The RagCorpus that owns the file, as a full resource name
            or a bare ``{rag_corpus}`` ID. Only consulted when ``name`` is a
            bare ID; may be falsy when ``name`` is a full resource name.

    Returns:
        ``name`` itself when it already parses as a RagFile path; otherwise
        ``{full corpus resource name}/ragFiles/{name}``.

    Raises:
        ValueError: ``name`` is empty or is neither a full RagFile resource
            name nor a valid bare ID, or ``name`` is a bare ID and
            ``corpus_name`` was not provided.
    """
    client = create_rag_data_service_client()
    if name and client.parse_rag_file_path(name):
        return name

    # A falsy name is rejected here with the same ValueError as any other
    # malformed name, instead of leaking a TypeError from the regex engine.
    # fullmatch (not match("^...$")) so a trailing newline is not accepted.
    if not name or not re.fullmatch(_VALID_RESOURCE_NAME_REGEX, name):
        raise ValueError(
            "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
        )

    if not corpus_name:
        raise ValueError(
            "corpus_name must be provided if name is a `{rag_file}`, not a "
            "full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
        )

    # get_corpus_name() returns the *full* corpus resource name, so it must
    # not be used as the `{rag_corpus}` ID segment of a file path (that nests
    # one full path inside another). Appending the file segment to it yields
    # a well-formed name and keeps whatever project/location the caller's
    # corpus_name carries.
    corpus_path = get_corpus_name(corpus_name)
    return "{}/ragFiles/{}".format(corpus_path, name)

0 comments on commit 26657ff

Please sign in to comment.