Skip to content

Commit

Permalink
feat: Vertex RAG for enhanced generative AI
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627454806
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Apr 23, 2024
1 parent 754c89d commit 39b5149
Show file tree
Hide file tree
Showing 14 changed files with 1,724 additions and 22 deletions.
21 changes: 18 additions & 3 deletions google/cloud/aiplatform/compat/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@
from google.cloud.aiplatform_v1beta1.services.model_service import (
client as model_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
client as pipeline_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
client as persistent_resource_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
client as pipeline_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as prediction_service_client_v1beta1,
)
Expand All @@ -90,10 +90,20 @@
from google.cloud.aiplatform_v1beta1.services.tensorboard_service import (
client as tensorboard_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
client as vertex_rag_data_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
async_client as vertex_rag_data_service_async_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_service import (
client as vertex_rag_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vizier_service import (
client as vizier_service_client_v1beta1,
)


from google.cloud.aiplatform_v1.services.dataset_service import (
client as dataset_service_client_v1,
)
Expand Down Expand Up @@ -195,9 +205,14 @@
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
prediction_service_async_client_v1beta1,
reasoning_engine_execution_service_client_v1beta1,
reasoning_engine_service_client_v1beta1,
schedule_service_client_v1beta1,
specialist_pool_service_client_v1beta1,
metadata_service_client_v1beta1,
tensorboard_service_client_v1beta1,
vertex_rag_service_client_v1beta1,
vertex_rag_data_service_client_v1beta1,
vertex_rag_data_service_async_client_v1beta1,
vizier_service_client_v1beta1,
)
36 changes: 36 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
persistent_resource_service_client_v1beta1,
reasoning_engine_service_client_v1beta1,
reasoning_engine_execution_service_client_v1beta1,
vertex_rag_data_service_async_client_v1beta1,
vertex_rag_data_service_client_v1beta1,
vertex_rag_service_client_v1beta1,
)
from google.cloud.aiplatform.compat.services import (
dataset_service_client_v1,
Expand Down Expand Up @@ -799,6 +802,39 @@ class ReasoningEngineExecutionClientWithOverride(ClientWithOverride):
)


class VertexRagDataClientWithOverride(ClientWithOverride):
    """Versioned client wrapper for the Vertex RAG data service.

    Only the v1beta1 GAPIC client is registered, and v1beta1 is also the
    default version.
    """

    # Single-entry sequence of (API version, GAPIC client class) pairs.
    _version_map = (
        (
            compat.V1BETA1,
            vertex_rag_data_service_client_v1beta1.VertexRagDataServiceClient,
        ),
    )
    _default_version = compat.V1BETA1
    # NOTE(review): the meaning of this flag is defined by ClientWithOverride,
    # which is not visible here -- confirm before changing.
    _is_temporary = True


class VertexRagDataAsyncClientWithOverride(ClientWithOverride):
    """Versioned client wrapper for the async Vertex RAG data service client.

    Only the v1beta1 async GAPIC client is registered, and v1beta1 is also
    the default version.
    """

    # Single-entry sequence of (API version, GAPIC async client class) pairs.
    _version_map = (
        (
            compat.V1BETA1,
            vertex_rag_data_service_async_client_v1beta1.VertexRagDataServiceAsyncClient,
        ),
    )
    _default_version = compat.V1BETA1
    # NOTE(review): the meaning of this flag is defined by ClientWithOverride,
    # which is not visible here -- confirm before changing.
    _is_temporary = True


class VertexRagClientWithOverride(ClientWithOverride):
    """Versioned client wrapper for the Vertex RAG service.

    Only the v1beta1 GAPIC client is registered, and v1beta1 is also the
    default version.
    """

    # Single-entry sequence of (API version, GAPIC client class) pairs.
    _version_map = (
        (
            compat.V1BETA1,
            vertex_rag_service_client_v1beta1.VertexRagServiceClient,
        ),
    )
    _default_version = compat.V1BETA1
    # NOTE(review): the meaning of this flag is defined by ClientWithOverride,
    # which is not visible here -- confirm before changing.
    _is_temporary = True


VertexAiServiceClientWithOverride = TypeVar(
"VertexAiServiceClientWithOverride",
DatasetClientWithOverride,
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/vertex_rag/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import patch
from google import auth
from google.api_core import operation as ga_operation
from google.auth import credentials as auth_credentials
from vertexai.preview import rag
from google.cloud.aiplatform_v1beta1 import (
DeleteRagCorpusRequest,
VertexRagDataServiceAsyncClient,
VertexRagDataServiceClient,
)
import test_rag_constants as tc
import mock
import pytest


# Shared mock credentials object; its attribute surface is restricted (via
# ``spec``) to that of an AnonymousCredentials instance.
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())


@pytest.fixture(scope="module")
def google_auth_mock():
    """Patch ``google.auth.default`` once per test module.

    Yields:
        The mock standing in for ``auth.default``. Calling it returns an
        ``(AnonymousCredentials instance, tc.TEST_PROJECT)`` tuple.
    """
    anonymous_creds = auth_credentials.AnonymousCredentials()
    default_patcher = mock.patch.object(auth, "default")
    with default_patcher as patched_default:
        patched_default.return_value = (anonymous_creds, tc.TEST_PROJECT)
        yield patched_default


@pytest.fixture
def authorized_session_mock():
    """Patch ``AuthorizedSession`` and yield a mocked session instance.

    Yields:
        The object obtained by calling the patched ``AuthorizedSession``
        class with ``_TEST_CREDENTIALS``.
    """
    patch_target = "google.auth.transport.requests.AuthorizedSession"
    with patch(patch_target) as session_cls_mock:
        # The call is evaluated before the value is handed to the test.
        yield session_cls_mock(_TEST_CREDENTIALS)


@pytest.fixture
def rag_data_client_mock():
    """Patch the RAG data-service client factory with a canned-success client.

    Yields:
        The mock replacing ``_gapic_utils.create_rag_data_service_client``.
        Calling it returns a ``VertexRagDataServiceClient``-spec'd mock whose
        ``get_rag_corpus``, ``delete_rag_corpus`` and ``get_rag_file``
        methods return fixed test values.
    """
    client_stub = mock.Mock(spec=VertexRagDataServiceClient)

    # Read calls hand back the canned GAPIC resources.
    client_stub.get_rag_corpus.return_value = tc.TEST_GAPIC_RAG_CORPUS
    client_stub.get_rag_file.return_value = tc.TEST_GAPIC_RAG_FILE

    # delete_rag_corpus returns an Operation-spec'd mock.
    # NOTE(review): the operation's result is a DeleteRagCorpusRequest
    # instance, which looks unusual for a delete result -- confirm intent.
    delete_lro = mock.Mock(spec=ga_operation.Operation)
    delete_lro.result.return_value = DeleteRagCorpusRequest()
    client_stub.delete_rag_corpus.return_value = delete_lro

    with mock.patch.object(
        rag.utils._gapic_utils, "create_rag_data_service_client"
    ) as factory_mock:
        factory_mock.return_value = client_stub
        yield factory_mock


@pytest.fixture
def rag_data_client_mock_exception():
    """Patch the RAG data-service client factory with an always-failing client.

    Yields:
        The mock replacing ``_gapic_utils.create_rag_data_service_client``.
        Calling it returns a ``VertexRagDataServiceClient``-spec'd mock on
        which every method listed below raises ``Exception`` when called.
    """
    failing_methods = (
        "create_rag_corpus",
        "get_rag_corpus",
        "list_rag_corpora",
        "delete_rag_corpus",
        "upload_rag_file",
        "import_rag_files",
        "get_rag_file",
        "list_rag_files",
        "delete_rag_file",
    )
    with mock.patch.object(
        rag.utils._gapic_utils, "create_rag_data_service_client"
    ) as factory_mock:
        failing_client = mock.Mock(spec=VertexRagDataServiceClient)
        for method_name in failing_methods:
            getattr(failing_client, method_name).side_effect = Exception
        factory_mock.return_value = failing_client
        yield factory_mock


@pytest.fixture
def rag_data_async_client_mock_exception():
    """Patch the async RAG data-service client factory with a failing client.

    Yields:
        The mock replacing
        ``_gapic_utils.create_rag_data_service_async_client``. Calling it
        returns a ``VertexRagDataServiceAsyncClient``-spec'd mock whose
        ``import_rag_files`` raises ``Exception`` when called.
    """
    with mock.patch.object(
        rag.utils._gapic_utils, "create_rag_data_service_async_client"
    ) as rag_data_async_client_mock_exception:
        api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClient)
        # import_rag_files
        api_client_mock.import_rag_files.side_effect = Exception
        # Attach the failing client to the *async* factory patch. Assigning
        # to the sibling sync fixture's name would leave this patch returning
        # an unconfigured default mock.
        rag_data_async_client_mock_exception.return_value = api_client_mock
        yield rag_data_async_client_mock_exception
148 changes: 148 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vertexai.preview.rag.utils.resources import (
RagCorpus,
RagFile,
)
from google.cloud import aiplatform
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
ImportRagFilesConfig,
ImportRagFilesRequest,
ImportRagFilesResponse,
RagCorpus as GapicRagCorpus,
RagFile as GapicRagFile,
RagContexts,
RetrieveContextsResponse,
)


# Project, location, and corpus identifiers used to build the constants below.
TEST_PROJECT = "test-project"
TEST_PROJECT_NUMBER = "12345678"
TEST_REGION = "us-central1"
TEST_CORPUS_DISPLAY_NAME = "my-corpus-1"
# NOTE(review): "DISCRIPTION" is a misspelling of "DESCRIPTION"; the name is
# left unchanged here because other modules may reference it.
TEST_CORPUS_DISCRIPTION = "My first corpus."
TEST_RAG_CORPUS_ID = "generate-123"
# Endpoint string: the literal "us-central1-" prefix plus the SDK's base path.
TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH
# Corpus resource name built from the project number, region, and corpus ID.
TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}"

# RagCorpus
# The same corpus expressed twice with identical field values: once as the
# GAPIC message type and once as the SDK-level RagCorpus resource.
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
)
TEST_RAG_CORPUS = RagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
)
# NOTE(review): presumably used by list/pagination tests -- confirm at call sites.
TEST_PAGE_TOKEN = "test-page-token"

# RagFiles
# Local path, GCS path, and display metadata for the test file.
TEST_PATH = "usr/home/my_file.txt"
TEST_GCS_PATH = "gs://usr/home/data_dir/"
TEST_FILE_DISPLAY_NAME = "my-file.txt"
TEST_FILE_DESCRIPTION = "my file."
TEST_HEADERS = {"X-Goog-Upload-Protocol": "multipart"}
# Upload URI assembled from the endpoint, project number, region, and corpus ID.
TEST_UPLOAD_REQUEST_URI = "https://{}/upload/v1beta1/projects/{}/locations/{}/ragCorpora/{}/ragFiles:upload".format(
    TEST_API_ENDPOINT, TEST_PROJECT_NUMBER, TEST_REGION, TEST_RAG_CORPUS_ID
)
TEST_RAG_FILE_ID = "generate-456"
# File resource name: the corpus resource name with a "/ragFiles/<id>" suffix.
TEST_RAG_FILE_RESOURCE_NAME = (
    TEST_RAG_CORPUS_RESOURCE_NAME + f"/ragFiles/{TEST_RAG_FILE_ID}"
)
TEST_UPLOAD_RAG_FILE_RESPONSE_CONTENT = ""
# JSON-shaped payload carrying the file's name and display name under "ragFile".
TEST_RAG_FILE_JSON = {
    "ragFile": {
        "name": TEST_RAG_FILE_RESOURCE_NAME,
        "displayName": TEST_FILE_DISPLAY_NAME,
    }
}
# JSON-shaped payload carrying an "error" object with code 13.
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
# Chunking parameters fed into RagFileChunkingConfig further down.
TEST_CHUNK_SIZE = 512
TEST_CHUNK_OVERLAP = 100
# GCS
# Import config whose GCS source lists the single test GCS path, and the
# import request that targets the test corpus with that config.
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
)
# Google Drive folders
TEST_DRIVE_FOLDER_ID = "123"
# Drive folder URL embedding the folder ID.
TEST_DRIVE_FOLDER = (
    f"https://drive.google.com/corp/drive/folders/{TEST_DRIVE_FOLDER_ID}"
)
# Import config whose Drive source holds one folder-typed resource ID.
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
    GoogleDriveSource.ResourceId(
        resource_id=TEST_DRIVE_FOLDER_ID,
        resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
    )
]
# Import request targeting the test corpus with the folder config.
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER,
)
# Google Drive files
TEST_DRIVE_FILE_ID = "456"
# Drive file URL embedding the file ID.
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
# Unlike the GCS and folder configs, this one also sets an explicit chunking
# config (chunk size and overlap).
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    )
)
# Drive source holds one file-typed resource ID.
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.google_drive_source.resource_ids = [
    GoogleDriveSource.ResourceId(
        resource_id=TEST_DRIVE_FILE_ID,
        resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE,
    )
]
# Import request targeting the test corpus with the file config.
TEST_IMPORT_REQUEST_DRIVE_FILE = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FILE,
)

# Import response reporting two imported files.
TEST_IMPORT_RESPONSE = ImportRagFilesResponse(imported_rag_files_count=2)

# The same file expressed twice with identical field values: once as the
# GAPIC message type and once as the SDK-level RagFile resource.
TEST_GAPIC_RAG_FILE = GapicRagFile(
    name=TEST_RAG_FILE_RESOURCE_NAME,
    display_name=TEST_FILE_DISPLAY_NAME,
    description=TEST_FILE_DESCRIPTION,
)
TEST_RAG_FILE = RagFile(
    name=TEST_RAG_FILE_RESOURCE_NAME,
    display_name=TEST_FILE_DISPLAY_NAME,
    description=TEST_FILE_DESCRIPTION,
)

# Retrieval
TEST_QUERY_TEXT = "What happen to the fox and the dog?"
# Two contexts: the first has both a source URI and text, the second text only.
TEST_CONTEXTS = RagContexts(
    contexts=[
        RagContexts.Context(
            source_uri="https://drive.google.com/file/d/123/view?usp=drivesdk",
            text="The quick brown fox jumps over the lazy dog.",
        ),
        RagContexts.Context(text="The slow red fox jumps over the lazy dog."),
    ]
)
# Retrieval response wrapping the contexts above.
TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS)

0 comments on commit 39b5149

Please sign in to comment.