fix: Make PipelineJobSchedule propagate labels to created PipelineJobs
PiperOrigin-RevId: 585812646
vertex-sdk-bot authored and Copybara-Service committed Nov 28, 2023
1 parent 3f56ae7 commit a34533f
Showing 2 changed files with 88 additions and 0 deletions.
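For context, a minimal usage sketch of the behavior this commit fixes, assuming the public google-cloud-aiplatform API; the project, template path, labels, and cron values below are illustrative placeholders. With this change, PipelineJobs created by a PipelineJobSchedule inherit the labels set on the PipelineJob passed at construction.

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project/location

# A PipelineJob whose labels should propagate to the runs the schedule creates.
job = aiplatform.PipelineJob(
    display_name="labeled-pipeline-job",
    template_path="gs://my-bucket/pipeline.yaml",  # placeholder compiled pipeline
    labels={"team": "ml-platform"},
)

schedule = aiplatform.PipelineJobSchedule(
    pipeline_job=job,
    display_name="labeled-pipeline-schedule",
)

# After this fix, each PipelineJob created by the schedule carries
# labels={"team": "ml-platform"} from the job above.
schedule.create(
    cron="0 9 * * *",
    max_concurrent_run_count=1,
    max_run_count=10,
)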
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/pipeline_job_schedules.py
@@ -102,6 +102,10 @@ def __init__(
            create_pipeline_job_request["pipeline_job"][
                "template_uri"
            ] = pipeline_job._gca_resource.template_uri
        if "labels" in pipeline_job._gca_resource:
            create_pipeline_job_request["pipeline_job"][
                "labels"
            ] = pipeline_job._gca_resource.labels
        pipeline_job_schedule_args = {
            "display_name": display_name,
            "create_pipeline_job_request": create_pipeline_job_request,
84 changes: 84 additions & 0 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -638,6 +638,90 @@ def test_call_schedule_service_create_uses_pipeline_job_project_location(
        assert pipeline_job_schedule.project == "managed-pipeline-test"
        assert pipeline_job_schedule.location == "europe-west4"

    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
    )
    def test_call_schedule_service_create_uses_pipeline_job_labels(
        self,
        mock_schedule_service_create,
        mock_pipeline_service_list,
        mock_schedule_service_get,
        mock_schedule_bucket_exists,
        job_spec,
        mock_load_yaml_and_json,
    ):
"""Creates a PipelineJobSchedule.
Tests that PipelineJobs created through PipelineJobSchedule inherit the labels of the init PipelineJob.
"""
        TEST_PIPELINE_JOB_LABELS = {"name": "test_xx"}

        aiplatform.init(
            project=_TEST_PROJECT,
            staging_bucket=_TEST_GCS_BUCKET_NAME,
            location=_TEST_LOCATION,
            credentials=_TEST_CREDENTIALS,
        )

        job = pipeline_jobs.PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            template_path=_TEST_TEMPLATE_PATH,
            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
            enable_caching=True,
            labels=TEST_PIPELINE_JOB_LABELS,
        )

        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
            pipeline_job=job,
            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
        )

        pipeline_job_schedule.create(
            cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
            service_account=_TEST_SERVICE_ACCOUNT,
            network=_TEST_NETWORK,
            create_request_timeout=None,
        )

        expected_runtime_config_dict = {
            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
            "inputArtifacts": {"vertex_model": {"artifactId": "456"}},
        }
        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
        json_format.ParseDict(expected_runtime_config_dict, runtime_config)

        job_spec = yaml.safe_load(job_spec)
        pipeline_spec = job_spec.get("pipelineSpec") or job_spec

        # Construct expected request
        expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
            cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
            create_pipeline_job_request={
                "parent": _TEST_PARENT,
                "pipeline_job": {
                    "runtime_config": runtime_config,
                    "pipeline_spec": dict_to_struct(pipeline_spec),
                    "labels": TEST_PIPELINE_JOB_LABELS,
                    "service_account": _TEST_SERVICE_ACCOUNT,
                    "network": _TEST_NETWORK,
                },
            },
        )

        mock_schedule_service_create.assert_called_once_with(
            parent=_TEST_PARENT,
            schedule=expected_gapic_pipeline_job_schedule,
            timeout=None,
        )

    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
