diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index 90d7e0f86d..bc50a47aa2 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -144,15 +144,15 @@ def __init__(
                 be encrypted with the provided encryption key. Overrides
                 encryption_spec_key_name set in aiplatform.init.
-            labels (Dict[str,str]):
+            labels (Dict[str, str]):
                 Optional. The user defined metadata to organize PipelineJob.
             credentials (auth_credentials.Credentials):
                 Optional. Custom credentials to use to create this PipelineJob.
                 Overrides credentials set in aiplatform.init.
-            project (str),
+            project (str):
                 Optional. The project that you want to run this PipelineJob in. If not set,
                 the project set in aiplatform.init will be used.
-            location (str),
+            location (str):
                 Optional. Location to create PipelineJob. If not set,
                 location set in aiplatform.init will be used.
@@ -215,9 +215,9 @@ def __init__(
         )
         if not _VALID_NAME_PATTERN.match(self.job_id):
             raise ValueError(
-                "Generated job ID: {} is illegal as a Vertex pipelines job ID. "
+                f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
                 "Expecting an ID following the regex pattern "
-                '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
             )
 
         if enable_caching is not None:
@@ -471,3 +471,147 @@ def list(
     def wait_for_resource_creation(self) -> None:
         """Waits until resource has been created."""
         self._wait_for_resource_creation()
+
+    def clone(
+        self,
+        display_name: Optional[str] = None,
+        job_id: Optional[str] = None,
+        pipeline_root: Optional[str] = None,
+        parameter_values: Optional[Dict[str, Any]] = None,
+        enable_caching: Optional[bool] = None,
+        encryption_spec_key_name: Optional[str] = None,
+        labels: Optional[Dict[str, str]] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+    ) -> "PipelineJob":
+        """Returns a new PipelineJob object with the same settings as the original one.
+
+        Args:
+            display_name (str):
+                Optional. The user-defined name of this cloned Pipeline.
+                If not specified, the display name of the original pipeline will be used.
+            job_id (str):
+                Optional. The unique ID of the job run.
+                If not specified, "cloned" + pipeline name + timestamp will be used.
+            pipeline_root (str):
+                Optional. The root of the pipeline outputs. Defaults to the same
+                staging bucket as the original pipeline.
+            parameter_values (Dict[str, Any]):
+                Optional. The mapping from runtime parameter names to their values that
+                control the pipeline run. Defaults to the same values as the original
+                PipelineJob.
+            enable_caching (bool):
+                Optional. Whether to turn on caching for the run.
+                If this is not set, defaults to the setting of the original pipeline.
+                If this is set, the setting applies to all tasks in the pipeline.
+            encryption_spec_key_name (str):
+                Optional. The Cloud KMS resource identifier of the customer
+                managed encryption key used to protect the job. Has the
+                form:
+                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+                The key needs to be in the same region as where the compute resource is created.
+                If this is set, then all
+                resources created by the PipelineJob will
+                be encrypted with the provided encryption key.
+                If not specified, the encryption_spec of the original PipelineJob will be used.
+            labels (Dict[str, str]):
+                Optional. The user defined metadata to organize PipelineJob.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this PipelineJob.
+                Overrides credentials set in aiplatform.init.
+            project (str):
+                Optional. The project that you want to run this PipelineJob in.
+                If not set, the project set in the original PipelineJob will be used.
+            location (str):
+                Optional. Location to create PipelineJob.
+                If not set, the location set in the original PipelineJob will be used.
+
+        Returns:
+            A Vertex AI PipelineJob.
+
+        Raises:
+            ValueError: If job_id or labels have incorrect format.
+        """
+        ## Initialize an empty PipelineJob
+        if not project:
+            project = self.project
+        if not location:
+            location = self.location
+        if not credentials:
+            credentials = self.credentials
+
+        cloned = self.__class__._empty_constructor(
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+        cloned._parent = initializer.global_config.common_location_path(
+            project=project, location=location
+        )
+
+        ## Get gca_resource from original PipelineJob
+        pipeline_job = json_format.MessageToDict(self._gca_resource._pb)
+
+        ## Set pipeline_spec
+        pipeline_spec = pipeline_job["pipelineSpec"]
+        if "deploymentConfig" in pipeline_spec:
+            del pipeline_spec["deploymentConfig"]
+
+        ## Set caching
+        if enable_caching is not None:
+            _set_enable_caching_value(pipeline_spec, enable_caching)
+
+        ## Set job_id
+        pipeline_name = pipeline_spec["pipelineInfo"]["name"]
+        cloned.job_id = job_id or "cloned-{pipeline_name}-{timestamp}".format(
+            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+            .lstrip("-")
+            .rstrip("-"),
+            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+        )
+        if not _VALID_NAME_PATTERN.match(cloned.job_id):
+            raise ValueError(
+                f"Generated job ID: {cloned.job_id} is illegal as a Vertex pipelines job ID. "
" + "Expecting an ID following the regex pattern " + f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"' + ) + + ## Set display_name, labels and encryption_spec + if display_name: + utils.validate_display_name(display_name) + elif not display_name and "displayName" in pipeline_job: + display_name = pipeline_job["displayName"] + + if labels: + utils.validate_labels(labels) + elif not labels and "labels" in pipeline_job: + labels = pipeline_job["labels"] + + if encryption_spec_key_name or "encryptionSpec" not in pipeline_job: + encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ) + else: + encryption_spec = pipeline_job["encryptionSpec"] + + ## Set runtime_config + builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json( + pipeline_job + ) + builder.update_pipeline_root(pipeline_root) + builder.update_runtime_parameters(parameter_values) + runtime_config_dict = builder.build() + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(runtime_config_dict, runtime_config) + + ## Create gca_resource for cloned PipelineJob + cloned._gca_resource = gca_pipeline_job.PipelineJob( + display_name=display_name, + pipeline_spec=pipeline_spec, + labels=labels, + runtime_config=runtime_config, + encryption_spec=encryption_spec, + ) + + return cloned diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index 159400f8ce..1f6f2bb50c 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -1038,3 +1038,166 @@ def test_pipeline_failure_raises(self, mock_load_yaml_and_json, sync): if not sync: job.wait() + + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + def test_clone_pipeline_job( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + job_spec, + mock_load_yaml_and_json, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_TEMPLATE_PATH, + job_id=_TEST_PIPELINE_JOB_ID, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + enable_caching=True, + ) + + cloned = job.clone(job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}") + + cloned.submit( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + create_request_timeout=None, + ) + + expected_runtime_config_dict = { + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, + } + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + job_spec = yaml.safe_load(job_spec) + pipeline_spec = job_spec.get("pipelineSpec") or job_spec + + # Construct expected request + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + pipeline_spec={ + "components": {}, + "pipelineInfo": pipeline_spec["pipelineInfo"], + "root": pipeline_spec["root"], + "schemaVersion": "2.1.0", + }, + runtime_config=runtime_config, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=_TEST_PARENT, + pipeline_job=expected_gapic_pipeline_job, + pipeline_job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}", + timeout=None, + ) + + assert not 
+        assert not mock_pipeline_service_get.called
+
+        cloned.wait()
+
+        mock_pipeline_service_get.assert_called_with(
+            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        assert cloned._gca_resource == make_pipeline_job(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_clone_pipeline_job_with_all_args(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        cloned = job.clone(
+            display_name=f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}",
+            job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
+            pipeline_root=f"cloned-{_TEST_GCS_BUCKET_NAME}",
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+            credentials=_TEST_CREDENTIALS,
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+
+        cloned.submit(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": f"cloned-{_TEST_GCS_BUCKET_NAME}",
+            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+        }
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        job_spec = yaml.safe_load(job_spec)
+        pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+
+        # Construct expected request
+        expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
+            display_name=f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}",
+            pipeline_spec={
+                "components": {},
+                "pipelineInfo": pipeline_spec["pipelineInfo"],
+                "root": pipeline_spec["root"],
+                "schemaVersion": "2.1.0",
+            },
+            runtime_config=runtime_config,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            pipeline_job=expected_gapic_pipeline_job,
+            pipeline_job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
+            timeout=None,
+        )
+
+        assert not mock_pipeline_service_get.called
+
+        cloned.wait()
+
+        mock_pipeline_service_get.assert_called_with(
+            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        assert cloned._gca_resource == make_pipeline_job(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
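
Usage sketch (not part of the diff): a minimal illustration of how the new clone() method composes with submit(). The project, bucket, template path, and parameter names below are placeholders rather than values from this change; per the docstring above, any argument not passed to clone() is carried over from the original job.

    from google.cloud import aiplatform

    # Placeholder project/bucket values -- substitute your own.
    aiplatform.init(
        project="my-project",
        location="us-central1",
        staging_bucket="gs://my-bucket",
    )

    # "pipeline_spec.json" and "learning_rate" are illustrative only.
    job = aiplatform.PipelineJob(
        display_name="training-pipeline",
        template_path="pipeline_spec.json",
        parameter_values={"learning_rate": 0.01},
    )
    job.submit()

    # Re-run the same pipeline with one parameter overridden; unset
    # arguments (pipeline_root, labels, encryption settings, ...) are
    # copied from the original job, and a "cloned-..." job ID is
    # generated when job_id is not given.
    cloned = job.clone(
        display_name="training-pipeline-rerun",
        parameter_values={"learning_rate": 0.001},
        enable_caching=False,
    )
    cloned.submit()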