diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py index d4e205c944..1a281d671c 100644 --- a/tests/system/aiplatform/test_language_models.py +++ b/tests/system/aiplatform/test_language_models.py @@ -19,7 +19,7 @@ from google.cloud import aiplatform from google.cloud.aiplatform.compat.types import ( - job_state_v1beta1 as gca_job_state_v1beta1, + job_state as gca_job_state, ) from tests.system.aiplatform import e2e_base from vertexai.preview.language_models import ( @@ -160,7 +160,7 @@ def test_tuning(self, shared_state): ) assert tuned_model_response.text - def test_batch_prediction(self): + def test_batch_prediction_for_text_generation(self): source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl" destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/text-bison@001_" @@ -178,4 +178,24 @@ def test_batch_prediction(self): gapic_job = job._gca_resource job.delete() - assert gapic_job.state == gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED + assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED + + def test_batch_prediction_for_textembedding(self): + source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl" + destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/textembedding-gecko@001_" + + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = TextEmbeddingModel.from_pretrained("textembedding-gecko") + job = model.batch_predict( + dataset=source_uri, + destination_uri_prefix=destination_uri_prefix, + model_parameters={}, + ) + + job.wait_for_resource_creation() + job.wait() + gapic_job = job._gca_resource + job.delete() + + assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index edb615a50c..7da36fbea3 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -141,7 +141,7 @@ "version_id": "001", "open_source_category": "PROPRIETARY", "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA, - "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001", + "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/textembedding-gecko@001", "predict_schemata": { "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml", "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml", @@ -1323,3 +1323,37 @@ def test_batch_prediction(self): gcs_destination_prefix="gs://test-bucket/results/", model_parameters={"temperature": 0.1}, ) + + def test_batch_prediction_for_text_embedding(self): + """Tests batch prediction.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT + ), + ): + model = preview_language_models.TextEmbeddingModel.from_pretrained( + "textembedding-gecko@001" + ) + + with mock.patch.object( + target=aiplatform.BatchPredictionJob, + attribute="create", + ) as mock_create: + model.batch_predict( + dataset="gs://test-bucket/test_table.jsonl", + destination_uri_prefix="gs://test-bucket/results/", + model_parameters={}, + ) + mock_create.assert_called_once_with( + model_name="publishers/google/models/textembedding-gecko@001", + job_display_name=None, + gcs_source="gs://test-bucket/test_table.jsonl", + gcs_destination_prefix="gs://test-bucket/results/", + model_parameters={}, + ) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index b5899c0761..378006a891 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -586,9 +586,7 @@ def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]: ] -class _PreviewTextEmbeddingModel(TextEmbeddingModel): - """Preview text embedding model.""" - +class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict): _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE