Skip to content

Commit

Permalink
feat: LLM - Released TextGenerationModel tuning to GA
Browse files Browse the repository at this point in the history
Changes from Preview:

* `tune_model()` no longer blocks while the model is being tuned.
* `tune_model()` no longer updates the model in-place once tuning has fnished. Instead, it returns a job that can be used to get the newly tuned model.
* `tune_model()` now returns a tuning job object. The `tuning_job.tuned_model` property can be used to get the tuned model, waiting for the tuning to finish if needed. (This is also working in Preview)
* The `learning_rate` parameter has been removed. Use `learning_rate_multiplier` instead.
* The default value for `train_steps` has changed from 1000 to the tuning pipeline default value (usually 300).

PiperOrigin-RevId: 558650122
  • Loading branch information
Ark-kun authored and Copybara-Service committed Aug 21, 2023
1 parent 6f7ea84 commit 62ff30d
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/system/aiplatform/test_language_models.py
Expand Up @@ -165,7 +165,7 @@ def test_tuning(self, shared_state):
"""Test tuning, listing and loading models."""
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

model = TextGenerationModel.from_pretrained("google/text-bison@001")
model = language_models.TextGenerationModel.from_pretrained("text-bison@001")

import pandas

Expand Down
80 changes: 80 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Expand Up @@ -1384,6 +1384,86 @@ def test_tune_text_generation_model(
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize(
"mock_request_urlopen",
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
def test_tune_text_generation_model_ga(
self,
mock_pipeline_service_create,
mock_pipeline_job_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
mock_gcs_from_string,
mock_gcs_upload,
mock_request_urlopen,
mock_get_tuned_model,
):
"""Tests tuning the text generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):
model = language_models.TextGenerationModel.from_pretrained(
"text-bison@001"
)

tuning_job_location = "europe-west4"
evaluation_data_uri = "gs://bucket/eval.jsonl"
evaluation_interval = 37
enable_early_stopping = True
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"

tuning_job = model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location=tuning_job_location,
tuned_model_location="us-central1",
learning_rate_multiplier=2.0,
train_steps=10,
tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
evaluation_data=evaluation_data_uri,
evaluation_interval=evaluation_interval,
enable_early_stopping=enable_early_stopping,
tensorboard=tensorboard_name,
),
)
call_kwargs = mock_pipeline_service_create.call_args[1]
pipeline_arguments = call_kwargs[
"pipeline_job"
].runtime_config.parameter_values
assert pipeline_arguments["learning_rate_multiplier"] == 2.0
assert pipeline_arguments["train_steps"] == 10
assert pipeline_arguments["evaluation_data_uri"] == evaluation_data_uri
assert pipeline_arguments["evaluation_interval"] == evaluation_interval
assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
assert pipeline_arguments["large_model_reference"] == "text-bison@001"
assert (
call_kwargs["pipeline_job"].encryption_spec.kms_key_name
== _TEST_ENCRYPTION_KEY_NAME
)

# Testing the tuned model
tuned_model = tuning_job.get_tuned_model()
assert (
tuned_model._endpoint_name
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
Expand Down
4 changes: 3 additions & 1 deletion vertexai/language_models/_language_models.py
Expand Up @@ -894,7 +894,9 @@ def batch_predict(
)


class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
class TextGenerationModel(
_TextGenerationModel, _TunableTextModelMixin, _ModelWithBatchPredict
):
pass


Expand Down

0 comments on commit 62ff30d

Please sign in to comment.