Skip to content

Commit

Permalink
feat: Add support for user-configurable 1P embedding models and quota…
Browse files Browse the repository at this point in the history
… for RAG

PiperOrigin-RevId: 642414350
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Jun 11, 2024
1 parent cf8bc3d commit 8b3beb6
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 3 deletions.
12 changes: 12 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
RagCorpus,
RagFile,
RagResource,
Expand Down Expand Up @@ -49,10 +50,19 @@
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
)
TEST_GAPIC_RAG_CORPUS.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
"projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
TEST_PROJECT, TEST_REGION
)
)
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
publisher_model="publishers/google/models/textembedding-gecko",
)
TEST_RAG_CORPUS = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG,
)
TEST_PAGE_TOKEN = "test-page-token"

Expand Down Expand Up @@ -114,6 +124,8 @@
chunk_overlap=TEST_CHUNK_OVERLAP,
)
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800

TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.google_drive_source.resource_ids = [
GoogleDriveSource.ResourceId(
resource_id=TEST_DRIVE_FILE_ID,
Expand Down
46 changes: 45 additions & 1 deletion tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vertexai.preview import rag
from vertexai.preview.rag.utils._gapic_utils import (
prepare_import_files_request,
set_embedding_model_config,
)
from google.cloud.aiplatform_v1beta1 import (
VertexRagDataServiceAsyncClient,
Expand Down Expand Up @@ -171,7 +172,10 @@ def teardown_method(self):

@pytest.mark.usefixtures("create_rag_corpus_mock")
def test_create_corpus_success(self):
rag_corpus = rag.create_corpus(display_name=tc.TEST_CORPUS_DISPLAY_NAME)
rag_corpus = rag.create_corpus(
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
embedding_model_config=tc.TEST_EMBEDDING_MODEL_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)

Expand Down Expand Up @@ -391,6 +395,7 @@ def test_prepare_import_files_request_drive_files(self):
paths=paths,
chunk_size=tc.TEST_CHUNK_SIZE,
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
max_embedding_requests_per_min=800,
)
import_files_request_eq(request, tc.TEST_IMPORT_REQUEST_DRIVE_FILE)

Expand All @@ -415,3 +420,42 @@ def test_prepare_import_files_request_invalid_path(self):
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
)
e.match("path must be a Google Cloud Storage uri or a Google Drive url")

def test_set_embedding_model_config_set_both_error(self):
embedding_model_config = rag.EmbeddingModelConfig(
publisher_model="whatever",
endpoint="whatever",
)
with pytest.raises(ValueError) as e:
set_embedding_model_config(
embedding_model_config,
tc.TEST_GAPIC_RAG_CORPUS,
)
e.match("publisher_model and endpoint cannot be set at the same time")

def test_set_embedding_model_config_not_set_error(self):
embedding_model_config = rag.EmbeddingModelConfig()
with pytest.raises(ValueError) as e:
set_embedding_model_config(
embedding_model_config,
tc.TEST_GAPIC_RAG_CORPUS,
)
e.match("At least one of publisher_model and endpoint must be set")

def test_set_embedding_model_config_wrong_publisher_model_format_error(self):
embedding_model_config = rag.EmbeddingModelConfig(publisher_model="whatever")
with pytest.raises(ValueError) as e:
set_embedding_model_config(
embedding_model_config,
tc.TEST_GAPIC_RAG_CORPUS,
)
e.match("publisher_model must be of the format ")

def test_set_embedding_model_config_wrong_endpoint_format_error(self):
embedding_model_config = rag.EmbeddingModelConfig(endpoint="whatever")
with pytest.raises(ValueError) as e:
set_embedding_model_config(
embedding_model_config,
tc.TEST_GAPIC_RAG_CORPUS,
)
e.match("endpoint must be of the format ")
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
VertexRagStore,
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
RagResource,
)

Expand All @@ -53,6 +54,7 @@
"list_files",
"delete_file",
"retrieval_query",
"EmbeddingModelConfig",
"Retrieval",
"VertexRagStore",
"RagResource",
Expand Down
34 changes: 33 additions & 1 deletion vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@
_gapic_utils,
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
RagCorpus,
RagFile,
)


def create_corpus(
display_name: Optional[str] = None, description: Optional[str] = None
display_name: Optional[str] = None,
description: Optional[str] = None,
embedding_model_config: Optional[EmbeddingModelConfig] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Expand All @@ -69,6 +72,7 @@ def create_corpus(
the RagCorpus. The name can be up to 128 characters long and can
consist of any UTF-8 characters.
description: The description of the RagCorpus.
embedding_model_config: The embedding model config.
Returns:
RagCorpus.
Raises:
Expand All @@ -80,6 +84,12 @@ def create_corpus(
parent = initializer.global_config.common_location_path(project=None, location=None)

rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
if embedding_model_config:
rag_corpus = _gapic_utils.set_embedding_model_config(
embedding_model_config,
rag_corpus,
)

request = CreateRagCorpusRequest(
parent=parent,
rag_corpus=rag_corpus,
Expand Down Expand Up @@ -264,6 +274,7 @@ def import_files(
chunk_size: int = 1024,
chunk_overlap: int = 200,
timeout: int = 600,
max_embedding_requests_per_min: int = 1000,
) -> ImportRagFilesResponse:
"""
Import files to an existing RagCorpus, wait until completion.
Expand Down Expand Up @@ -299,6 +310,15 @@ def import_files(
"https://drive.google.com/corp/drive/folders/...").
chunk_size: The size of the chunks.
chunk_overlap: The overlap between chunks.
max_embedding_requests_per_min:
Optional. The max number of queries per
minute that this job is allowed to make to the
embedding model specified on the corpus. This
value is specific to this job and not shared
across other import jobs. Consult the Quotas
page on the project to set an appropriate value
here. If unspecified, a default value of 1,000
QPM would be used.
timeout: Default is 600 seconds.
Returns:
ImportRagFilesResponse.
Expand All @@ -309,6 +329,7 @@ def import_files(
paths=paths,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
max_embedding_requests_per_min=max_embedding_requests_per_min,
)
client = _gapic_utils.create_rag_data_service_client()
try:
Expand All @@ -324,6 +345,7 @@ async def import_files_async(
paths: Sequence[str],
chunk_size: int = 1024,
chunk_overlap: int = 200,
max_embedding_requests_per_min: int = 1000,
) -> operation_async.AsyncOperation:
"""
Import files to an existing RagCorpus asynchronously.
Expand Down Expand Up @@ -361,6 +383,15 @@ async def import_files_async(
"https://drive.google.com/corp/drive/folders/...").
chunk_size: The size of the chunks.
chunk_overlap: The overlap between chunks.
max_embedding_requests_per_min:
Optional. The max number of queries per
minute that this job is allowed to make to the
embedding model specified on the corpus. This
value is specific to this job and not shared
across other import jobs. Consult the Quotas
page on the project to set an appropriate value
here. If unspecified, a default value of 1,000
QPM would be used.
Returns:
operation_async.AsyncOperation.
"""
Expand All @@ -370,6 +401,7 @@ async def import_files_async(
paths=paths,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
max_embedding_requests_per_min=max_embedding_requests_per_min,
)
async_client = _gapic_utils.create_rag_data_service_async_client()
try:
Expand Down
99 changes: 98 additions & 1 deletion vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
from typing import Any, Dict, Sequence, Union
from google.cloud.aiplatform_v1beta1 import (
RagEmbeddingModelConfig,
GoogleDriveSource,
ImportRagFilesConfig,
ImportRagFilesRequest,
Expand All @@ -31,6 +32,7 @@
VertexRagClientWithOverride,
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
RagCorpus,
RagFile,
)
Expand All @@ -57,12 +59,43 @@ def create_rag_service_client():
)


def convert_gapic_to_embedding_model_config(
gapic_embedding_model_config: RagEmbeddingModelConfig,
) -> EmbeddingModelConfig:
"""Convert GapicRagEmbeddingModelConfig to EmbeddingModelConfig."""
embedding_model_config = EmbeddingModelConfig()
path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
publisher_model = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
path,
)
endpoint = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
path,
)
if publisher_model:
embedding_model_config.publisher_model = path
if endpoint:
embedding_model_config.endpoint = path
embedding_model_config.model = (
gapic_embedding_model_config.vertex_prediction_endpoint.model
)
embedding_model_config.model_version_id = (
gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id
)

return embedding_model_config


def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
""" "Convert GapicRagCorpus to RagCorpus."""
rag_corpus = RagCorpus(
name=gapic_rag_corpus.name,
display_name=gapic_rag_corpus.display_name,
description=gapic_rag_corpus.description,
embedding_model_config=convert_gapic_to_embedding_model_config(
gapic_rag_corpus.rag_embedding_model_config
),
)
return rag_corpus

Expand Down Expand Up @@ -124,6 +157,7 @@ def prepare_import_files_request(
paths: Sequence[str],
chunk_size: int = 1024,
chunk_overlap: int = 200,
max_embedding_requests_per_min: int = 1000,
) -> ImportRagFilesRequest:
if len(corpus_name.split("/")) != 6:
raise ValueError(
Expand All @@ -135,7 +169,8 @@ def prepare_import_files_request(
chunk_overlap=chunk_overlap,
)
import_rag_files_config = ImportRagFilesConfig(
rag_file_chunking_config=rag_file_chunking_config
rag_file_chunking_config=rag_file_chunking_config,
max_embedding_requests_per_min=max_embedding_requests_per_min,
)

uris = []
Expand Down Expand Up @@ -204,3 +239,65 @@ def get_file_name(
raise ValueError(
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
)


def set_embedding_model_config(
embedding_model_config: EmbeddingModelConfig,
rag_corpus: GapicRagCorpus,
) -> GapicRagCorpus:
if embedding_model_config.publisher_model and embedding_model_config.endpoint:
raise ValueError("publisher_model and endpoint cannot be set at the same time.")
if (
not embedding_model_config.publisher_model
and not embedding_model_config.endpoint
):
raise ValueError("At least one of publisher_model and endpoint must be set.")
parent = initializer.global_config.common_location_path(project=None, location=None)

if embedding_model_config.publisher_model:
publisher_model = embedding_model_config.publisher_model
full_resource_name = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
publisher_model,
)
resource_name = re.match(
r"^publishers/google/models/(?P<model_id>.+?)$",
publisher_model,
)
if full_resource_name:
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
publisher_model
)
elif resource_name:
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
parent + "/" + publisher_model
)
else:
raise ValueError(
"publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`"
)

if embedding_model_config.endpoint:
endpoint = embedding_model_config.endpoint
full_resource_name = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
endpoint,
)
resource_name = re.match(
r"^endpoints/(?P<endpoint>.+?)$",
endpoint,
)
if full_resource_name:
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
endpoint
)
elif resource_name:
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
parent + "/" + endpoint
)
else:
raise ValueError(
"endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
)

return rag_corpus
Loading

0 comments on commit 8b3beb6

Please sign in to comment.