diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index c4ce74d9fa..0d696added 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -2125,12 +2125,18 @@ def test_tune_chat_model( ): model = language_models.ChatModel.from_pretrained("chat-bison@001") + tuning_job_location = "europe-west4" + tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123" + default_context = "Default context" tuning_job = model.tune_model( training_data=_TEST_TEXT_BISON_TRAINING_DF, tuning_job_location="europe-west4", tuned_model_location="us-central1", default_context=default_context, + tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec( + tensorboard=tensorboard_name, + ), accelerator_type="TPU", ) call_kwargs = mock_pipeline_service_create.call_args[1] @@ -2140,6 +2146,7 @@ def test_tune_chat_model( assert pipeline_arguments["large_model_reference"] == "chat-bison@001" assert pipeline_arguments["default_context"] == default_context assert pipeline_arguments["accelerator_type"] == "TPU" + assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name # Testing the tuned model tuned_model = tuning_job.get_tuned_model() @@ -2148,6 +2155,26 @@ def test_tune_chat_model( == test_constants.EndpointConstants._TEST_ENDPOINT_NAME ) + unsupported_tuning_evaluation_spec_att = ( + {"evaluation_data": "gs://bucket/eval.jsonl"}, + {"evaluation_interval": 37}, + {"enable_early_stopping": True}, + {"enable_checkpoint_selection": True}, + ) + for unsupported_att in unsupported_tuning_evaluation_spec_att: + unsupported_tuning_evaluation_spec = ( + preview_language_models.TuningEvaluationSpec(**unsupported_att) + ) + with pytest.raises(AttributeError): + model.tune_model( + training_data=_TEST_TEXT_BISON_TRAINING_DF, + tuning_job_location="europe-west4", + tuned_model_location="us-central1", + default_context=default_context, + tuning_evaluation_spec=unsupported_tuning_evaluation_spec, + accelerator_type="TPU", + ) + @pytest.mark.parametrize( "job_spec", [_TEST_PIPELINE_SPEC_JSON], @@ -2228,12 +2255,18 @@ def test_tune_code_chat_model( ): model = language_models.CodeChatModel.from_pretrained("codechat-bison@001") + tuning_job_location = "europe-west4" + tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123" + # The tune_model call needs to be inside the PublisherModel mock # since it gets a new PublisherModel when tuning completes. model.tune_model( training_data=_TEST_TEXT_BISON_TRAINING_DF, tuning_job_location="europe-west4", tuned_model_location="us-central1", + tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec( + tensorboard=tensorboard_name, + ), accelerator_type="TPU", ) call_kwargs = mock_pipeline_service_create.call_args[1] @@ -2242,6 +2275,26 @@ def test_tune_code_chat_model( ].runtime_config.parameter_values assert pipeline_arguments["large_model_reference"] == "codechat-bison@001" assert pipeline_arguments["accelerator_type"] == "TPU" + assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name + + unsupported_tuning_evaluation_spec_att = ( + {"evaluation_data": "gs://bucket/eval.jsonl"}, + {"evaluation_interval": 37}, + {"enable_early_stopping": True}, + {"enable_checkpoint_selection": True}, + ) + for unsupported_att in unsupported_tuning_evaluation_spec_att: + unsupported_tuning_evaluation_spec = ( + preview_language_models.TuningEvaluationSpec(**unsupported_att) + ) + with pytest.raises(AttributeError): + model.tune_model( + training_data=_TEST_TEXT_BISON_TRAINING_DF, + tuning_job_location="europe-west4", + tuned_model_location="us-central1", + tuning_evaluation_spec=unsupported_tuning_evaluation_spec, + accelerator_type="TPU", + ) @pytest.mark.usefixtures( "get_model_with_tuned_version_label_mock", diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 7e061431f5..956d1cb7df 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -496,6 +496,7 @@ def tune_model( model_display_name: Optional[str] = None, default_context: Optional[str] = None, accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None, + tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None, ) -> "_LanguageModelTuningJob": """Tunes a model based on training data. @@ -520,6 +521,7 @@ def tune_model( model_display_name: Custom display name for the tuned model. default_context: The context to use for all training samples by default. accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU". + tuning_evaluation_spec: Specification for the model evaluation during tuning. Returns: A `LanguageModelTuningJob` object that represents the tuning job. @@ -529,8 +531,25 @@ def tune_model( ValueError: If the "tuning_job_location" value is not supported ValueError: If the "tuned_model_location" value is not supported RuntimeError: If the model does not support tuning + AttributeError: If any attribute in the "tuning_evaluation_spec" is not supported """ - # Note: Chat models do not support tuning_evaluation_spec + + if tuning_evaluation_spec is not None: + unsupported_chat_model_tuning_eval_spec = { + "evaluation_data": tuning_evaluation_spec.evaluation_data, + "evaluation_interval": tuning_evaluation_spec.evaluation_interval, + "enable_early_stopping": tuning_evaluation_spec.enable_early_stopping, + "enable_checkpoint_selection": tuning_evaluation_spec.enable_checkpoint_selection, + } + + for att_name, att_value in unsupported_chat_model_tuning_eval_spec.items(): + if not att_value is None: + raise AttributeError( + ( + f"ChatModel and CodeChatModel only support tensorboard as attribute for TuningEvaluationSpec" + f"found attribute name {att_name} with value {att_value}, please leave {att_name} to None" + ) + ) return super().tune_model( training_data=training_data, train_steps=train_steps, @@ -540,6 +559,7 @@ def tune_model( model_display_name=model_display_name, default_context=default_context, accelerator_type=accelerator_type, + tuning_evaluation_spec=tuning_evaluation_spec, )