feat: add a way to easily clone a PipelineJob (#1239)
* Add batch_size kwarg for batch prediction jobs

* Fix errors

Update the copyright year. Change the order of the argument. Fix the syntax error.

* fix: change description layout

* feat: add clone method to PipelineJob

* fix: blacken and lint

* Update pipeline_jobs.py

* fix: update library names

* fix: formatting error
jaycee-li committed Jun 3, 2022
1 parent b6bf6dc commit efaf6ed
Showing 2 changed files with 313 additions and 6 deletions.
156 changes: 150 additions & 6 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -144,15 +144,15 @@ def __init__(
                 be encrypted with the provided encryption key.
                 Overrides encryption_spec_key_name set in aiplatform.init.
-            labels (Dict[str,str]):
+            labels (Dict[str, str]):
                 Optional. The user defined metadata to organize PipelineJob.
             credentials (auth_credentials.Credentials):
                 Optional. Custom credentials to use to create this PipelineJob.
                 Overrides credentials set in aiplatform.init.
-            project (str),
+            project (str):
                 Optional. The project that you want to run this PipelineJob in. If not set,
                 the project set in aiplatform.init will be used.
-            location (str),
+            location (str):
                 Optional. Location to create PipelineJob. If not set,
                 location set in aiplatform.init will be used.
@@ -215,9 +215,9 @@ def __init__(
         )
         if not _VALID_NAME_PATTERN.match(self.job_id):
             raise ValueError(
-                "Generated job ID: {} is illegal as a Vertex pipelines job ID. "
+                f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
                 "Expecting an ID following the regex pattern "
-                '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
             )

         if enable_caching is not None:
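A side note on the error-message fix above: in the old code the trailing `.format(job_id)` bound only to the last adjacent string literal, so the `{}` placeholder in the first line was never filled, and the message referenced the raw `job_id` argument rather than the generated `self.job_id`. The new code switches to f-strings and derives the displayed pattern from the compiled regex itself. A minimal sketch of that idea, assuming the module defines the pattern with `^`/`$` anchors (the definition below mirrors the library but is an assumption, not part of this diff):

import re

# Assumed definition, mirroring _VALID_NAME_PATTERN in pipeline_jobs.py.
_VALID_NAME_PATTERN = re.compile(r"^[a-z][-a-z0-9]{0,127}$")

job_id = "Invalid_ID"
if not _VALID_NAME_PATTERN.match(job_id):
    # pattern[1:-1] strips the ^ and $ anchors for a cleaner error message.
    print(
        f"Generated job ID: {job_id} is illegal as a Vertex pipelines job ID. "
        f'Expecting an ID following the regex pattern "{_VALID_NAME_PATTERN.pattern[1:-1]}"'
    )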
@@ -471,3 +471,147 @@ def list(
    def wait_for_resource_creation(self) -> None:
        """Waits until resource has been created."""
        self._wait_for_resource_creation()

    def clone(
        self,
        display_name: Optional[str] = None,
        job_id: Optional[str] = None,
        pipeline_root: Optional[str] = None,
        parameter_values: Optional[Dict[str, Any]] = None,
        enable_caching: Optional[bool] = None,
        encryption_spec_key_name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
    ) -> "PipelineJob":
        """Returns a new PipelineJob object with the same settings as the original one.

        Args:
            display_name (str):
                Optional. The user-defined name of this cloned Pipeline.
                If not specified, the original pipeline display name will be used.
            job_id (str):
                Optional. The unique ID of the job run.
                If not specified, "cloned" + pipeline name + timestamp will be used.
            pipeline_root (str):
                Optional. The root of the pipeline outputs. Defaults to the same
                staging bucket as the original pipeline.
            parameter_values (Dict[str, Any]):
                Optional. The mapping from runtime parameter names to the values that
                control the pipeline run. Defaults to the same values as the original
                PipelineJob.
            enable_caching (bool):
                Optional. Whether to turn on caching for the run.
                If not set, defaults to the same setting as the original pipeline.
                If set, the setting applies to all tasks in the pipeline.
            encryption_spec_key_name (str):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the job. Has the form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute resource is created.
                If this is set, then all resources created by the PipelineJob will
                be encrypted with the provided encryption key.
                If not specified, the encryption_spec of the original PipelineJob will be used.
            labels (Dict[str, str]):
                Optional. The user defined metadata to organize PipelineJob.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to create this PipelineJob.
                Overrides credentials set in aiplatform.init.
            project (str):
                Optional. The project that you want to run this PipelineJob in.
                If not set, the project set in the original PipelineJob will be used.
            location (str):
                Optional. Location to create PipelineJob.
                If not set, the location set in the original PipelineJob will be used.

        Returns:
            A Vertex AI PipelineJob.

        Raises:
            ValueError: If job_id or labels have incorrect format.
        """
        ## Initialize an empty PipelineJob
        if not project:
            project = self.project
        if not location:
            location = self.location
        if not credentials:
            credentials = self.credentials

        cloned = self.__class__._empty_constructor(
            project=project,
            location=location,
            credentials=credentials,
        )
        cloned._parent = initializer.global_config.common_location_path(
            project=project, location=location
        )

        ## Get gca_resource from the original PipelineJob
        pipeline_job = json_format.MessageToDict(self._gca_resource._pb)

        ## Set pipeline_spec
        pipeline_spec = pipeline_job["pipelineSpec"]
        if "deploymentConfig" in pipeline_spec:
            del pipeline_spec["deploymentConfig"]

        ## Set caching
        if enable_caching is not None:
            _set_enable_caching_value(pipeline_spec, enable_caching)

        ## Set job_id
        pipeline_name = pipeline_spec["pipelineInfo"]["name"]
        cloned.job_id = job_id or "cloned-{pipeline_name}-{timestamp}".format(
            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
            .lstrip("-")
            .rstrip("-"),
            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
        )
        if not _VALID_NAME_PATTERN.match(cloned.job_id):
            raise ValueError(
                f"Generated job ID: {cloned.job_id} is illegal as a Vertex pipelines job ID. "
                "Expecting an ID following the regex pattern "
                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
            )

        ## Set display_name, labels and encryption_spec
        if display_name:
            utils.validate_display_name(display_name)
        elif not display_name and "displayName" in pipeline_job:
            display_name = pipeline_job["displayName"]

        if labels:
            utils.validate_labels(labels)
        elif not labels and "labels" in pipeline_job:
            labels = pipeline_job["labels"]

        if encryption_spec_key_name or "encryptionSpec" not in pipeline_job:
            encryption_spec = initializer.global_config.get_encryption_spec(
                encryption_spec_key_name=encryption_spec_key_name
            )
        else:
            encryption_spec = pipeline_job["encryptionSpec"]

        ## Set runtime_config
        builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
            pipeline_job
        )
        builder.update_pipeline_root(pipeline_root)
        builder.update_runtime_parameters(parameter_values)
        runtime_config_dict = builder.build()
        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
        json_format.ParseDict(runtime_config_dict, runtime_config)

        ## Create gca_resource for the cloned PipelineJob
        cloned._gca_resource = gca_pipeline_job.PipelineJob(
            display_name=display_name,
            pipeline_spec=pipeline_spec,
            labels=labels,
            runtime_config=runtime_config,
            encryption_spec=encryption_spec,
        )

        return cloned
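For orientation, a minimal usage sketch of the new API. The project, location, template path, and parameter names below are illustrative, not part of this commit:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Hypothetical compiled pipeline spec; any KFP v2 template works the same way.
job = aiplatform.PipelineJob(
    display_name="training-pipeline",
    template_path="pipeline.json",
    parameter_values={"learning_rate": 0.01},
)
job.submit()

# clone() reuses the original pipeline spec, encryption spec, and labels;
# only the overridden arguments change. A fresh job ID is generated as
# "cloned-<pipeline-name>-<timestamp>" unless one is supplied.
rerun = job.clone(parameter_values={"learning_rate": 0.001})
rerun.submit()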
163 changes: 163 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -1038,3 +1038,166 @@ def test_pipeline_failure_raises(self, mock_load_yaml_and_json, sync):

        if not sync:
            job.wait()

    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
    )
    def test_clone_pipeline_job(
        self,
        mock_pipeline_service_create,
        mock_pipeline_service_get,
        job_spec,
        mock_load_yaml_and_json,
    ):
        aiplatform.init(
            project=_TEST_PROJECT,
            staging_bucket=_TEST_GCS_BUCKET_NAME,
            location=_TEST_LOCATION,
            credentials=_TEST_CREDENTIALS,
        )

        job = pipeline_jobs.PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            template_path=_TEST_TEMPLATE_PATH,
            job_id=_TEST_PIPELINE_JOB_ID,
            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
            enable_caching=True,
        )

        cloned = job.clone(job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}")

        cloned.submit(
            service_account=_TEST_SERVICE_ACCOUNT,
            network=_TEST_NETWORK,
            create_request_timeout=None,
        )

        expected_runtime_config_dict = {
            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
        }
        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
        json_format.ParseDict(expected_runtime_config_dict, runtime_config)

        job_spec = yaml.safe_load(job_spec)
        pipeline_spec = job_spec.get("pipelineSpec") or job_spec

        # Construct expected request
        expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            pipeline_spec={
                "components": {},
                "pipelineInfo": pipeline_spec["pipelineInfo"],
                "root": pipeline_spec["root"],
                "schemaVersion": "2.1.0",
            },
            runtime_config=runtime_config,
            service_account=_TEST_SERVICE_ACCOUNT,
            network=_TEST_NETWORK,
        )

        mock_pipeline_service_create.assert_called_once_with(
            parent=_TEST_PARENT,
            pipeline_job=expected_gapic_pipeline_job,
            pipeline_job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
            timeout=None,
        )

        assert not mock_pipeline_service_get.called

        cloned.wait()

        mock_pipeline_service_get.assert_called_with(
            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
        )

        assert cloned._gca_resource == make_pipeline_job(
            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
        )

    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
    )
    def test_clone_pipeline_job_with_all_args(
        self,
        mock_pipeline_service_create,
        mock_pipeline_service_get,
        job_spec,
        mock_load_yaml_and_json,
    ):
        aiplatform.init(
            project=_TEST_PROJECT,
            staging_bucket=_TEST_GCS_BUCKET_NAME,
            location=_TEST_LOCATION,
            credentials=_TEST_CREDENTIALS,
        )

        job = pipeline_jobs.PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            template_path=_TEST_TEMPLATE_PATH,
            job_id=_TEST_PIPELINE_JOB_ID,
            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
            enable_caching=True,
        )

        cloned = job.clone(
            display_name=f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}",
            job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
            pipeline_root=f"cloned-{_TEST_GCS_BUCKET_NAME}",
            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
            enable_caching=True,
            credentials=_TEST_CREDENTIALS,
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

        cloned.submit(
            service_account=_TEST_SERVICE_ACCOUNT,
            network=_TEST_NETWORK,
            create_request_timeout=None,
        )

        expected_runtime_config_dict = {
            "gcsOutputDirectory": f"cloned-{_TEST_GCS_BUCKET_NAME}",
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
        }
        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
        json_format.ParseDict(expected_runtime_config_dict, runtime_config)

        job_spec = yaml.safe_load(job_spec)
        pipeline_spec = job_spec.get("pipelineSpec") or job_spec

        # Construct expected request
        expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
            display_name=f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}",
            pipeline_spec={
                "components": {},
                "pipelineInfo": pipeline_spec["pipelineInfo"],
                "root": pipeline_spec["root"],
                "schemaVersion": "2.1.0",
            },
            runtime_config=runtime_config,
            service_account=_TEST_SERVICE_ACCOUNT,
            network=_TEST_NETWORK,
        )

        mock_pipeline_service_create.assert_called_once_with(
            parent=_TEST_PARENT,
            pipeline_job=expected_gapic_pipeline_job,
            pipeline_job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
            timeout=None,
        )

        assert not mock_pipeline_service_get.called

        cloned.wait()

        mock_pipeline_service_get.assert_called_with(
            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
        )

        assert cloned._gca_resource == make_pipeline_job(
            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
        )
