Skip to content

Commit

Permalink
feat: LLM - Support for Batch Prediction for the textembedding mode…
Browse files Browse the repository at this point in the history
…ls (preview)

Usage:
```
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
job = model.batch_predict(
    dataset="gs://<bicket>/dataset.jsonl",
    destination_uri_prefix="gs://<bicket>/batch_prediction/",
    # Optional:
    model_parameters={},
)
```
PiperOrigin-RevId: 551663844
  • Loading branch information
Ark-kun authored and Copybara-Service committed Jul 27, 2023
1 parent 7d72bd1 commit a368538
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 7 deletions.
26 changes: 23 additions & 3 deletions tests/system/aiplatform/test_language_models.py
Expand Up @@ -19,7 +19,7 @@

from google.cloud import aiplatform
from google.cloud.aiplatform.compat.types import (
job_state_v1beta1 as gca_job_state_v1beta1,
job_state as gca_job_state,
)
from tests.system.aiplatform import e2e_base
from vertexai.preview.language_models import (
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_tuning(self, shared_state):
)
assert tuned_model_response.text

def test_batch_prediction(self):
def test_batch_prediction_for_text_generation(self):
source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/text-bison@001_"

Expand All @@ -178,4 +178,24 @@ def test_batch_prediction(self):
gapic_job = job._gca_resource
job.delete()

assert gapic_job.state == gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED
assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED

def test_batch_prediction_for_textembedding(self):
source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/textembedding-gecko@001_"

aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

model = TextEmbeddingModel.from_pretrained("textembedding-gecko")
job = model.batch_predict(
dataset=source_uri,
destination_uri_prefix=destination_uri_prefix,
model_parameters={},
)

job.wait_for_resource_creation()
job.wait()
gapic_job = job._gca_resource
job.delete()

assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
36 changes: 35 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
Expand Up @@ -141,7 +141,7 @@
"version_id": "001",
"open_source_category": "PROPRIETARY",
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/textembedding-gecko@001",
"predict_schemata": {
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml",
"parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml",
Expand Down Expand Up @@ -1323,3 +1323,37 @@ def test_batch_prediction(self):
gcs_destination_prefix="gs://test-bucket/results/",
model_parameters={"temperature": 0.1},
)

def test_batch_prediction_for_text_embedding(self):
"""Tests batch prediction."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.TextEmbeddingModel.from_pretrained(
"textembedding-gecko@001"
)

with mock.patch.object(
target=aiplatform.BatchPredictionJob,
attribute="create",
) as mock_create:
model.batch_predict(
dataset="gs://test-bucket/test_table.jsonl",
destination_uri_prefix="gs://test-bucket/results/",
model_parameters={},
)
mock_create.assert_called_once_with(
model_name="publishers/google/models/textembedding-gecko@001",
job_display_name=None,
gcs_source="gs://test-bucket/test_table.jsonl",
gcs_destination_prefix="gs://test-bucket/results/",
model_parameters={},
)
4 changes: 1 addition & 3 deletions vertexai/language_models/_language_models.py
Expand Up @@ -586,9 +586,7 @@ def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
]


class _PreviewTextEmbeddingModel(TextEmbeddingModel):
"""Preview text embedding model."""

class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


Expand Down

0 comments on commit a368538

Please sign in to comment.