Skip to content

Commit

Permalink
feat: Support CMEK for scheduled pipeline jobs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590396834
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Dec 13, 2023
1 parent 6a00ed7 commit 406595d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def __init__(
create_pipeline_job_request["pipeline_job"][
"labels"
] = pipeline_job._gca_resource.labels
if "encryption_spec" in pipeline_job._gca_resource:
create_pipeline_job_request["pipeline_job"][
"encryption_spec"
] = pipeline_job._gca_resource.encryption_spec
pipeline_job_schedule_args = {
"display_name": display_name,
"create_pipeline_job_request": create_pipeline_job_request,
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from google.cloud.aiplatform.compat.types import (
context as gca_context,
encryption_spec as gca_encryption_spec_compat,
pipeline_job as gca_pipeline_job,
pipeline_state as gca_pipeline_state,
schedule as gca_schedule,
Expand Down Expand Up @@ -722,6 +723,92 @@ def test_call_schedule_service_create_uses_pipeline_job_labels(
timeout=None,
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_call_schedule_service_create_uses_pipeline_job_encryption_spec_key_name(
self,
mock_schedule_service_create,
mock_pipeline_service_list,
mock_schedule_service_get,
mock_schedule_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
"""Creates a PipelineJobSchedule.
Tests that PipelineJobs created through PipelineJobSchedule inherit the encryption_spec_key_name of the init PipelineJob.
"""
TEST_PIPELINE_JOB_ENCRYPTION_SPEC_KEY_NAME = "encryption_spec_key_name"

aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
encryption_spec_key_name=TEST_PIPELINE_JOB_ENCRYPTION_SPEC_KEY_NAME,
)

pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
pipeline_job=job,
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
)

pipeline_job_schedule.create(
cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
create_request_timeout=None,
)

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
"inputArtifacts": {"vertex_model": {"artifactId": "456"}},
}
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

job_spec = yaml.safe_load(job_spec)
pipeline_spec = job_spec.get("pipelineSpec") or job_spec

# Construct expected request
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
"parent": _TEST_PARENT,
"pipeline_job": {
"runtime_config": runtime_config,
"pipeline_spec": dict_to_struct(pipeline_spec),
"encryption_spec": gca_encryption_spec_compat.EncryptionSpec(
kms_key_name=TEST_PIPELINE_JOB_ENCRYPTION_SPEC_KEY_NAME
),
"service_account": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
},
)

mock_schedule_service_create.assert_called_once_with(
parent=_TEST_PARENT,
schedule=expected_gapic_pipeline_job_schedule,
timeout=None,
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
Expand Down

0 comments on commit 406595d

Please sign in to comment.