feat: Support PreflightValidation in Preview PipelineJob submit function.

PiperOrigin-RevId: 628707894
vertex-sdk-bot authored and Copybara-Service committed Apr 27, 2024
1 parent 1341e2c commit e88dc0d
Showing 2 changed files with 358 additions and 7 deletions.
319 changes: 312 additions & 7 deletions google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py
@@ -15,21 +15,64 @@
# limitations under the License.
#

import datetime
import re
from typing import Any, Dict, List, Optional

from google.auth import credentials as auth_credentials
from google.cloud import aiplatform_v1beta1
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import pipeline_job_schedules
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
from google.cloud.aiplatform.metadata import constants as metadata_constants
from google.cloud.aiplatform.metadata import experiment_resources
from google.cloud.aiplatform.pipeline_jobs import (
PipelineJob as PipelineJobGa,
)
from google.cloud.aiplatform_v1.services.pipeline_service import (
PipelineServiceClient as PipelineServiceClientGa,
)
from google.protobuf import json_format


_LOGGER = base.Logger(__name__)

# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL


def _get_current_time() -> datetime.datetime:
"""Gets the current timestamp."""
return datetime.datetime.now()


def _set_enable_caching_value(
pipeline_spec: Dict[str, Any], enable_caching: bool
) -> None:
"""Sets pipeline tasks caching options.
Args:
pipeline_spec (Dict[str, Any]):
Required. The dictionary of pipeline spec.
enable_caching (bool):
Required. Whether to enable caching.
"""
for component in [pipeline_spec["root"]] + list(
pipeline_spec["components"].values()
):
if "dag" in component:
for task in component["dag"]["tasks"].values():
task["cachingOptions"] = {"enableCache": enable_caching}


class _PipelineJob(
@@ -42,6 +85,192 @@ class _PipelineJob(
):
"""Preview PipelineJob resource for Vertex AI."""

def __init__(
self,
display_name: str,
template_path: str,
job_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
parameter_values: Optional[Dict[str, Any]] = None,
input_artifacts: Optional[Dict[str, str]] = None,
enable_caching: Optional[bool] = None,
encryption_spec_key_name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
credentials: Optional[auth_credentials.Credentials] = None,
project: Optional[str] = None,
location: Optional[str] = None,
failure_policy: Optional[str] = None,
enable_preflight_validations: Optional[bool] = False,
):
"""Retrieves a PipelineJob resource and instantiates its
representation.

Args:
display_name (str):
Required. The user-defined name of this Pipeline.
template_path (str):
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
an Artifact Registry URI (e.g.
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
pipeline_root (str):
Optional. The root of the pipeline outputs. If not set, the staging bucket
set in aiplatform.init will be used. If that's not set, a pipeline-specific
artifacts bucket will be used.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter names to their values,
which control the pipeline run.
input_artifacts (Dict[str, str]):
Optional. The mapping from the runtime parameter name for this artifact
to its resource id. For example: "vertex_model": "456". Note: the full
resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456")
cannot be used.
enable_caching (bool):
Optional. Whether to turn on caching for the run.
If this is not set, defaults to the compile time settings, which
are True for all tasks by default, while users may specify
different caching options for individual tasks.
If this is set, the setting applies to all tasks in the pipeline.
Overrides the compile time settings.
encryption_spec_key_name (str):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the job. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If this is set, then all
resources created by the PipelineJob will
be encrypted with the provided encryption key.
Overrides encryption_spec_key_name set in aiplatform.init.
labels (Dict[str, str]):
Optional. The user-defined metadata used to organize the PipelineJob.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to create this PipelineJob.
Overrides credentials set in aiplatform.init.
project (str):
Optional. The project that you want to run this PipelineJob in. If not set,
the project set in aiplatform.init will be used.
location (str):
Optional. Location to create PipelineJob. If not set,
location set in aiplatform.init will be used.
failure_policy (str):
Optional. The failure policy, either "slow" or "fast".
By default, a pipeline continues to run until no more tasks can be
executed, also known as PIPELINE_FAILURE_POLICY_FAIL_SLOW
(corresponds to "slow"). If a pipeline is set to
PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"), it stops
scheduling new tasks once any task has failed; tasks that are
already scheduled will run to completion.
enable_preflight_validations (bool):
Optional. Whether to enable preflight validations.

Raises:
ValueError: If job_id or labels have incorrect format.
"""

super().__init__(
display_name=display_name,
template_path=template_path,
job_id=job_id,
pipeline_root=pipeline_root,
parameter_values=parameter_values,
input_artifacts=input_artifacts,
enable_caching=enable_caching,
encryption_spec_key_name=encryption_spec_key_name,
labels=labels,
credentials=credentials,
project=project,
location=location,
failure_policy=failure_policy,
)

# Rebuild the v1beta1 versions of the pipeline job and its runtime config.
pipeline_json = utils.yaml_utils.load_yaml(
template_path, self.project, self.credentials
)

# pipeline_json can be either a PipelineJob or a PipelineSpec.
if pipeline_json.get("pipelineSpec") is not None:
pipeline_job = pipeline_json
pipeline_root = (
pipeline_root
or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
or initializer.global_config.staging_bucket
)
else:
pipeline_job = {
"pipelineSpec": pipeline_json,
"runtimeConfig": {},
}
pipeline_root = (
pipeline_root
or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
or initializer.global_config.staging_bucket
)
pipeline_root = (
pipeline_root
or utils.gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
project=project,
location=location,
)
)
builder = utils.pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
pipeline_job
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
builder.update_input_artifacts(input_artifacts)

builder.update_failure_policy(failure_policy)
runtime_config_dict = builder.build()

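# json_format.ParseDict fills the underlying v1beta1 proto (_pb) in place.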
runtime_config = aiplatform_v1beta1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(runtime_config_dict, runtime_config)

pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
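# Sanitize the pipeline name into a valid job ID of the form
# "<sanitized-name>-<YYYYMMDDHHMMSS>", e.g. "My_Pipeline" ->
# "my-pipeline-20240427120000" (illustrative timestamp).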
self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
.lstrip("-")
.rstrip("-"),
timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
)
if not _VALID_NAME_PATTERN.match(self.job_id):
raise ValueError(
f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
"Expecting an ID following the regex pattern "
f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
)

if enable_caching is not None:
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)

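# New in this change: the preflight validations flag is carried on the
# v1beta1 PipelineJob resource itself.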
pipeline_job_args = {
"display_name": display_name,
"pipeline_spec": pipeline_job["pipelineSpec"],
"labels": labels,
"runtime_config": runtime_config,
"encryption_spec": initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
"preflight_validations": enable_preflight_validations,
}

if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
pipeline_job_args["template_uri"] = template_path

self._v1_beta1_pipeline_job = aiplatform_v1beta1.PipelineJob(
**pipeline_job_args
)

def create_schedule(
self,
cron_expression: str,
@@ -180,3 +409,79 @@ def batch_delete(
v1beta1_client = client.select_version(compat.V1BETA1)
operation = v1beta1_client.batch_delete_pipeline_jobs(request)
return operation.result()

def submit(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
reserved_ip_ranges: Optional[List[str]] = None,
create_request_timeout: Optional[float] = None,
job_id: Optional[str] = None,
) -> None:
"""Run this configured PipelineJob.
Args:
service_account (str):
Optional. Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
Optional. The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the network set in aiplatform.init will be used.
Otherwise, the job is not peered with any network.
reserved_ip_ranges (List[str]):
Optional. A list of names for the reserved IP ranges under the VPC
network that can be used for this PipelineJob's workload. For example: ['vertex-ai-ip-range'].
If left unspecified, the job will be deployed to any IP ranges under
the provided VPC network.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
job_id (str):
Optional. The ID to use for the PipelineJob, which will become the final
component of the PipelineJob name. If not provided, an ID will be
automatically generated.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account
gca_resource = self._v1_beta1_pipeline_job

if service_account:
gca_resource.service_account = service_account

if network:
gca_resource.network = network

if reserved_ip_ranges:
gca_resource.reserved_ip_ranges = reserved_ip_ranges

user_project = initializer.global_config.project
user_location = initializer.global_config.location
parent = initializer.global_config.common_location_path(
project=user_project, location=user_location
)

client = self._instantiate_client(
location=user_location,
appended_user_agent=["preview-pipeline-job-submit"],
)
v1beta1_client = client.select_version(compat.V1BETA1)

_LOGGER.log_create_with_lro(self.__class__)

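# The pipeline_job in this request already carries the
# preflight_validations flag set at construction time.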
request = aiplatform_v1beta1.CreatePipelineJobRequest(
parent=parent,
pipeline_job=self._v1_beta1_pipeline_job,
pipeline_job_id=job_id or self.job_id,
)

response = v1beta1_client.create_pipeline_job(request=request)

self._gca_resource = response

_LOGGER.log_create_complete_with_getter(
self.__class__, self._gca_resource, "pipeline_job"
)

_LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
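
The net effect of the change: the preview constructor accepts enable_preflight_validations and submit() forwards it to the v1beta1 CreatePipelineJob call. A minimal sketch of the intended call pattern, assuming the preview class is imported directly from the module changed in this commit; the project, bucket, template, and service account names are illustrative placeholders:

from google.cloud import aiplatform
from google.cloud.aiplatform.preview.pipelinejob import pipeline_jobs

aiplatform.init(project="my-project", location="us-central1")

# Placeholder compiled template and pipeline root; substitute real resources.
job = pipeline_jobs._PipelineJob(
    display_name="my-pipeline",
    template_path="gs://my-bucket/pipeline.yaml",
    pipeline_root="gs://my-bucket/pipeline-root",
    enable_preflight_validations=True,  # the flag added by this commit
)

# submit() sends the v1beta1 PipelineJob (which already carries the flag)
# via a CreatePipelineJobRequest on the v1beta1 client.
job.submit(service_account="sa@my-project.iam.gserviceaccount.com")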
