Skip to content

Commit

Permalink
feat: LLM - Text Embedding - Added validation for text embedding tuni…
Browse files Browse the repository at this point in the history
…ng parameters.

PiperOrigin-RevId: 632301450
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed May 9, 2024
1 parent cb8b10f commit 5a300c1
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
52 changes: 52 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,58 @@ def test_tune_text_embedding_model(
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)

@pytest.mark.parametrize(
"optional_tune_args,error_regex",
[
(
dict(test_data="/tmp/bucket/test.tsv"),
"Each tuning dataset file must be a Google Cloud Storage URI starting with 'gs://'.",
),
(
dict(output_dimensionality=-1),
"output_dimensionality must be an integer between 1 and 768",
),
(
dict(learning_rate_multiplier=0),
"learning_rate_multiplier must be greater than 0",
),
(
dict(train_steps=29),
"train_steps must be greater than or equal to 30",
),
(
dict(batch_size=2048),
"batch_size must be between 1 and 1024",
),
],
)
def test_tune_text_embedding_model_invalid_values(
self, optional_tune_args, error_regex
):
"""Tests that certain embedding tuning values fail validation."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_GECKO_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.TextEmbeddingModel.from_pretrained(
"text-multilingual-embedding-002"
)
with pytest.raises(ValueError, match=error_regex):
model.tune_model(
training_data="gs://bucket/training.tsv",
corpus_data="gs://bucket/corpus.jsonl",
queries_data="gs://bucket/queries.jsonl",
**optional_tune_args,
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
Expand Down
34 changes: 31 additions & 3 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2192,8 +2192,6 @@ async def get_embeddings_async(

# TODO(b/625884109): Support Union[str, "pandas.core.frame.DataFrame"]
# for corpus, queries, test and validation data.
# TODO(b/625884109): Validate input args, batch_size >0 and train_steps >30, and
# task_type must be 'DEFAULT' or None if _model_id is textembedding-gecko@001.
class _PreviewTunableTextEmbeddingModelMixin(_TunableModelMixin):
@classmethod
def get_tuned_model(cls, *args, **kwargs):
Expand Down Expand Up @@ -2265,9 +2263,39 @@ def tune_model(
Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
Raises:
ValueError: If the "tuned_model_location" value is not supported
ValueError: If the provided parameter combinations or values are not
supported.
RuntimeError: If the model does not support tuning
"""
if batch_size is not None and batch_size not in range(1, 1024):
raise ValueError(
f"batch_size must be between 1 and 1024. Given {batch_size}."
)
if train_steps is not None and train_steps < 30:
raise ValueError(
f"train_steps must be greater than or equal to 30. Given {train_steps}."
)
if learning_rate_multiplier is not None and learning_rate_multiplier <= 0:
raise ValueError(
f"learning_rate_multiplier must be greater than 0. Given {learning_rate_multiplier}."
)
if output_dimensionality is not None and output_dimensionality not in range(
1, 769
):
raise ValueError(
f"output_dimensionality must be an integer between 1 and 768. Given {output_dimensionality}."
)
for dataset in [
training_data,
corpus_data,
queries_data,
test_data,
validation_data,
]:
if dataset is not None and not dataset.startswith("gs://"):
raise ValueError(
f"Each tuning dataset file must be a Google Cloud Storage URI starting with 'gs://'. Given {dataset}."
)

return super().tune_model(
training_data=training_data,
Expand Down

0 comments on commit 5a300c1

Please sign in to comment.