diff --git a/google/cloud/aiplatform/pipeline_job_schedules.py b/google/cloud/aiplatform/pipeline_job_schedules.py index 04cda69c99..00f9b86746 100644 --- a/google/cloud/aiplatform/pipeline_job_schedules.py +++ b/google/cloud/aiplatform/pipeline_job_schedules.py @@ -102,6 +102,10 @@ def __init__( create_pipeline_job_request["pipeline_job"][ "template_uri" ] = pipeline_job._gca_resource.template_uri + if "labels" in pipeline_job._gca_resource: + create_pipeline_job_request["pipeline_job"][ + "labels" + ] = pipeline_job._gca_resource.labels pipeline_job_schedule_args = { "display_name": display_name, "create_pipeline_job_request": create_pipeline_job_request, diff --git a/tests/unit/aiplatform/test_pipeline_job_schedules.py b/tests/unit/aiplatform/test_pipeline_job_schedules.py index 4c4c6273ab..0021ded3ba 100644 --- a/tests/unit/aiplatform/test_pipeline_job_schedules.py +++ b/tests/unit/aiplatform/test_pipeline_job_schedules.py @@ -638,6 +638,90 @@ def test_call_schedule_service_create_uses_pipeline_job_project_location( assert pipeline_job_schedule.project == "managed-pipeline-test" assert pipeline_job_schedule.location == "europe-west4" + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + def test_call_schedule_service_create_uses_pipeline_job_labels( + self, + mock_schedule_service_create, + mock_pipeline_service_list, + mock_schedule_service_get, + mock_schedule_bucket_exists, + job_spec, + mock_load_yaml_and_json, + ): + """Creates a PipelineJobSchedule. + + Tests that PipelineJobs created through PipelineJobSchedule inherit the labels of the init PipelineJob. + """ + TEST_PIPELINE_JOB_LABELS = {"name": "test_xx"} + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_TEMPLATE_PATH, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS, + enable_caching=True, + labels=TEST_PIPELINE_JOB_LABELS, + ) + + pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule( + pipeline_job=job, + display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME, + ) + + pipeline_job_schedule.create( + cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON, + max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT, + max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + create_request_timeout=None, + ) + + expected_runtime_config_dict = { + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, + "inputArtifacts": {"vertex_model": {"artifactId": "456"}}, + } + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + job_spec = yaml.safe_load(job_spec) + pipeline_spec = job_spec.get("pipelineSpec") or job_spec + + # Construct expected request + expected_gapic_pipeline_job_schedule = gca_schedule.Schedule( + display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME, + cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON, + max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT, + max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT, + create_pipeline_job_request={ + "parent": _TEST_PARENT, + "pipeline_job": { + "runtime_config": runtime_config, + "pipeline_spec": dict_to_struct(pipeline_spec), + "labels": TEST_PIPELINE_JOB_LABELS, + "service_account": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, + }, + }, + ) + + mock_schedule_service_create.assert_called_once_with( + parent=_TEST_PARENT, + schedule=expected_gapic_pipeline_job_schedule, + timeout=None, + ) + @pytest.mark.parametrize( "job_spec", [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],