diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py
index 7538387da9..61a6f9ad1d 100644
--- a/google/cloud/aiplatform/compat/__init__.py
+++ b/google/cloud/aiplatform/compat/__init__.py
@@ -39,6 +39,7 @@
     services.model_garden_service_client = services.model_garden_service_client_v1beta1
     services.pipeline_service_client = services.pipeline_service_client_v1beta1
     services.prediction_service_client = services.prediction_service_client_v1beta1
+    services.schedule_service_client = services.schedule_service_client_v1beta1
     services.specialist_pool_service_client = (
         services.specialist_pool_service_client_v1beta1
     )
@@ -114,6 +115,8 @@
     types.pipeline_state = types.pipeline_state_v1beta1
     types.prediction_service = types.prediction_service_v1beta1
     types.publisher_model = types.publisher_model_v1beta1
+    types.schedule = types.schedule_v1beta1
+    types.schedule_service = types.schedule_service_v1beta1
     types.specialist_pool = types.specialist_pool_v1beta1
     types.specialist_pool_service = types.specialist_pool_service_v1beta1
     types.study = types.study_v1beta1
diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py
index 4dc073c3b2..5edd4fdd3b 100644
--- a/google/cloud/aiplatform/compat/services/__init__.py
+++ b/google/cloud/aiplatform/compat/services/__init__.py
@@ -57,6 +57,9 @@
 from google.cloud.aiplatform_v1beta1.services.prediction_service import (
     client as prediction_service_client_v1beta1,
 )
+from google.cloud.aiplatform_v1beta1.services.schedule_service import (
+    client as schedule_service_client_v1beta1,
+)
 from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
     client as specialist_pool_service_client_v1beta1,
 )
@@ -140,6 +143,7 @@
     model_service_client_v1beta1,
     pipeline_service_client_v1beta1,
     prediction_service_client_v1beta1,
+    schedule_service_client_v1beta1,
     specialist_pool_service_client_v1beta1,
     metadata_service_client_v1beta1,
     tensorboard_service_client_v1beta1,
diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py
index dcc0fa4e62..41f2123246 100644
--- a/google/cloud/aiplatform/compat/types/__init__.py
+++ b/google/cloud/aiplatform/compat/types/__init__.py
@@ -75,6 +75,8 @@
     pipeline_state as pipeline_state_v1beta1,
     prediction_service as prediction_service_v1beta1,
     publisher_model as publisher_model_v1beta1,
+    schedule as schedule_v1beta1,
+    schedule_service as schedule_service_v1beta1,
     specialist_pool as specialist_pool_v1beta1,
     specialist_pool_service as specialist_pool_service_v1beta1,
     study as study_v1beta1,
@@ -283,6 +285,8 @@
     pipeline_state_v1beta1,
     prediction_service_v1beta1,
     publisher_model_v1beta1,
+    schedule_v1beta1,
+    schedule_service_v1beta1,
     specialist_pool_v1beta1,
     specialist_pool_service_v1beta1,
     study_v1beta1,
diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index 19de1fc88f..42c65ca6d4 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -17,9 +17,9 @@

 import datetime
 import logging
-import time
 import re
 import tempfile
+import time
 from typing import Any, Callable, Dict, List, Optional, Union

 from google.auth import credentials as auth_credentials
@@ -29,16 +29,16 @@
 from google.cloud.aiplatform import utils
 from google.cloud.aiplatform.constants import pipeline as pipeline_constants
 from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import constants as metadata_constants
 from google.cloud.aiplatform.metadata import context
 from google.cloud.aiplatform.metadata import execution
-from google.cloud.aiplatform.metadata import constants as metadata_constants
 from google.cloud.aiplatform.metadata import experiment_resources
 from google.cloud.aiplatform.metadata import utils as metadata_utils
 from google.cloud.aiplatform.utils import gcs_utils
-from google.cloud.aiplatform.utils import yaml_utils
 from google.cloud.aiplatform.utils import pipeline_utils
-from google.protobuf import json_format
+from google.cloud.aiplatform.utils import yaml_utils
 from google.protobuf import field_mask_pb2 as field_mask
+from google.protobuf import json_format

 from google.cloud.aiplatform.compat.types import (
     pipeline_job as gca_pipeline_job,
@@ -96,7 +96,6 @@ class PipelineJob(
         ),
     ),
 ):
-
     client_class = utils.PipelineJobClientWithOverride
     _resource_noun = "pipelineJobs"
     _delete_method = "delete_pipeline_job"
@@ -443,6 +442,10 @@ def wait(self):
     def pipeline_spec(self):
         return self._gca_resource.pipeline_spec

+    @property
+    def runtime_config(self) -> gca_pipeline_job.PipelineJob.RuntimeConfig:
+        return self._gca_resource.runtime_config
+
     @property
     def state(self) -> Optional[gca_pipeline_state.PipelineState]:
         """Current pipeline state."""
diff --git a/google/cloud/aiplatform/preview/constants/schedules.py b/google/cloud/aiplatform/preview/constants/schedules.py
new file mode 100644
index 0000000000..ee053985d4
--- /dev/null
+++ b/google/cloud/aiplatform/preview/constants/schedules.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.compat.types import (
+    schedule_v1beta1 as gca_schedule,
+)
+from google.cloud.aiplatform.constants import pipeline as pipeline_constants
+
+_SCHEDULE_COMPLETE_STATES = set(
+    [
+        gca_schedule.Schedule.State.PAUSED,
+        gca_schedule.Schedule.State.COMPLETED,
+    ]
+)
+
+_SCHEDULE_ERROR_STATES = set(
+    [
+        gca_schedule.Schedule.State.STATE_UNSPECIFIED,
+    ]
+)
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = pipeline_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
+
+# Fields to include in returned PipelineJobSchedule when enable_simple_view=True in PipelineJobSchedule.list()
+_PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS = [
+    "name",
+    "display_name",
+    "start_time",
+    "end_time",
+    "max_run_count",
+    "started_run_count",
+    "state",
+    "create_time",
+    "update_time",
+    "cron",
+    "catch_up",
+]
diff --git a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
new file mode 100644
index 0000000000..8520646b9f
--- /dev/null
+++ b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import (
+    PipelineJob,
+)
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import (
+    schedule_v1beta1 as gca_schedule,
+)
+from google.cloud.aiplatform.preview.constants import (
+    schedules as schedule_constants,
+)
+from google.cloud.aiplatform.preview.schedule.schedules import _Schedule
+
+# TODO(b/283318141): Remove imports once PipelineJobSchedule is GA.
+from google.cloud.aiplatform_v1.types import (
+    pipeline_job as gca_pipeline_job_v1,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    pipeline_job as gca_pipeline_job_v1beta1,
+)
+
+
+_LOGGER = base.Logger(__name__)
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = schedule_constants._VALID_NAME_PATTERN
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = schedule_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = schedule_constants._VALID_HTTPS_URL
+
+_READ_MASK_FIELDS = schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS
+
+
+class PipelineJobSchedule(
+    _Schedule,
+):
+    def __init__(
+        self,
+        pipeline_job: PipelineJob,
+        display_name: str,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+    ):
+        """Retrieves a PipelineJobSchedule resource and instantiates its
+        representation.
+        Args:
+            pipeline_job (PipelineJob):
+                Required. PipelineJob used to init the schedule.
+            display_name (str):
+                Required. The user-defined name of this PipelineJobSchedule.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this PipelineJobSchedule.
+                Overrides credentials set in aiplatform.init.
+            project (str):
+                Optional. The project that you want to run this PipelineJobSchedule in.
+                If not set, the project set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to create PipelineJobSchedule. If not set,
+                location set in aiplatform.init will be used.
+        """
+        if not display_name:
+            display_name = self.__class__._generate_display_name()
+        utils.validate_display_name(display_name)
+
+        super().__init__(credentials=credentials, project=project, location=location)
+
+        self._parent = initializer.global_config.common_location_path(
+            project=project, location=location
+        )
+
+        # TODO(b/283318141): Remove temporary logic once PipelineJobSchedule is GA.
+        runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig.deserialize(
+            gca_pipeline_job_v1.PipelineJob.RuntimeConfig.serialize(
+                pipeline_job.runtime_config
+            )
+        )
+        create_pipeline_job_request = {
+            "parent": self._parent,
+            "pipeline_job": {
+                "runtime_config": runtime_config,
+                "pipeline_spec": {"fields": pipeline_job.pipeline_spec},
+            },
+        }
+        pipeline_job_schedule_args = {
+            "display_name": display_name,
+            "create_pipeline_job_request": create_pipeline_job_request,
+        }
+
+        self._gca_resource = gca_schedule.Schedule(**pipeline_job_schedule_args)
+
+    def create(
+        self,
+        cron_expression: str,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+        allow_queueing: bool = False,
+        max_run_count: Optional[int] = None,
+        max_concurrent_run_count: int = 1,
+        service_account: Optional[str] = None,
+        network: Optional[str] = None,
+        create_request_timeout: Optional[float] = None,
+    ) -> None:
+        """Create a PipelineJobSchedule.
+
+        Args:
+            cron_expression (str):
+                Required. Time specification (cron schedule expression) to launch scheduled runs.
+                To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+                ${IANA_TIME_ZONE} must be a valid string from the IANA time zone database.
+                For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+            start_time (str):
+                Optional. Timestamp after which the first run can be scheduled.
+                If unspecified, it defaults to the schedule creation timestamp.
+            end_time (str):
+                Optional. Timestamp after which no more runs will be scheduled.
+                If unspecified, then runs will be scheduled indefinitely.
+            allow_queueing (bool):
+                Optional. Whether new scheduled runs can be queued when the max_concurrent_run_count limit is reached.
+            max_run_count (int):
+                Optional. Maximum run count of the schedule.
+                If specified, the schedule will be completed when either started_run_count >= max_run_count or end_time is reached.
+            max_concurrent_run_count (int):
+                Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+            service_account (str):
+                Optional. Specifies the service account for workload run-as account.
+                Users submitting jobs must have act-as permission on this run-as account.
+            network (str):
+                Optional. The full name of the Compute Engine network to which the job
+                should be peered. For example, projects/12345/global/networks/myVPC.
+                Private services access must already be configured for the network.
+                If left unspecified, the network set in aiplatform.init will be used.
+                Otherwise, the job is not peered with any network.
+            create_request_timeout (float):
+                Optional. The timeout for the create request in seconds.
+        """
+        network = network or initializer.global_config.network
+
+        self._create(
+            cron_expression=cron_expression,
+            start_time=start_time,
+            end_time=end_time,
+            allow_queueing=allow_queueing,
+            max_run_count=max_run_count,
+            max_concurrent_run_count=max_concurrent_run_count,
+            service_account=service_account,
+            network=network,
+            create_request_timeout=create_request_timeout,
+        )
+
+    def _create(
+        self,
+        cron_expression: str,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+        allow_queueing: bool = False,
+        max_run_count: Optional[int] = None,
+        max_concurrent_run_count: int = 1,
+        service_account: Optional[str] = None,
+        network: Optional[str] = None,
+        create_request_timeout: Optional[float] = None,
+    ) -> None:
+        """Helper method to create the PipelineJobSchedule.
+
+        Args:
+            cron_expression (str):
+                Required. Time specification (cron schedule expression) to launch scheduled runs.
+                To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+                ${IANA_TIME_ZONE} must be a valid string from the IANA time zone database.
+                For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+            start_time (str):
+                Optional. Timestamp after which the first run can be scheduled.
+                If unspecified, it defaults to the schedule creation timestamp.
+            end_time (str):
+                Optional. Timestamp after which no more runs will be scheduled.
+                If unspecified, then runs will be scheduled indefinitely.
+            allow_queueing (bool):
+                Optional. Whether new scheduled runs can be queued when the max_concurrent_run_count limit is reached.
+            max_run_count (int):
+                Optional. Maximum run count of the schedule.
+                If specified, the schedule will be completed when either started_run_count >= max_run_count or end_time is reached.
+            max_concurrent_run_count (int):
+                Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+            service_account (str):
+                Optional. Specifies the service account for workload run-as account.
+                Users submitting jobs must have act-as permission on this run-as account.
+            network (str):
+                Optional. The full name of the Compute Engine network to which the job
+                should be peered. For example, projects/12345/global/networks/myVPC.
+                Private services access must already be configured for the network.
+                If left unspecified, the network set in aiplatform.init will be used.
+                Otherwise, the job is not peered with any network.
+            create_request_timeout (float):
+                Optional. The timeout for the create request in seconds.
+        """
+        if cron_expression:
+            self._gca_resource.cron = cron_expression
+        if start_time:
+            self._gca_resource.start_time = start_time
+        if end_time:
+            self._gca_resource.end_time = end_time
+        if allow_queueing:
+            self._gca_resource.allow_queueing = allow_queueing
+        if max_run_count:
+            self._gca_resource.max_run_count = max_run_count
+        if max_concurrent_run_count:
+            self._gca_resource.max_concurrent_run_count = max_concurrent_run_count
+
+        network = network or initializer.global_config.network
+
+        if service_account:
+            self._gca_resource.create_pipeline_job_request.pipeline_job.service_account = (
+                service_account
+            )
+
+        if network:
+            self._gca_resource.create_pipeline_job_request.pipeline_job.network = (
+                network
+            )
+
+        _LOGGER.log_create_with_lro(self.__class__)
+
+        self._gca_resource = self.api_client.create_schedule(
+            parent=self._parent,
+            schedule=self._gca_resource,
+            timeout=create_request_timeout,
+        )
+
+        _LOGGER.log_create_complete_with_getter(
+            self.__class__, self._gca_resource, "schedule"
+        )
+
+        _LOGGER.info("View Schedule:\n%s" % self._dashboard_uri())
diff --git a/google/cloud/aiplatform/preview/schedule/schedules.py b/google/cloud/aiplatform/preview/schedule/schedules.py
new file mode 100644
index 0000000000..400c7be82f
--- /dev/null
+++ b/google/cloud/aiplatform/preview/schedule/schedules.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from typing import Any, Optional
+
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import (
+    schedule_v1beta1 as gca_schedule,
+)
+from google.cloud.aiplatform.preview.constants import (
+    schedules as schedule_constants,
+)
+
+_LOGGER = base.Logger(__name__)
+
+_SCHEDULE_COMPLETE_STATES = schedule_constants._SCHEDULE_COMPLETE_STATES
+
+_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
+
+
+class _Schedule(
+    base.VertexAiStatefulResource,
+):
+    """Preview Schedule resource for Vertex AI."""
+
+    client_class = utils.ScheduleClientWithOverride
+    _resource_noun = "schedules"
+    _delete_method = "delete_schedule"
+    _getter_method = "get_schedule"
+    _list_method = "list_schedules"
+    _pause_method = "pause_schedule"
+    _resume_method = "resume_schedule"
+    _parse_resource_name_method = "parse_schedule_path"
+    _format_resource_name_method = "schedule_path"
+
+    # Required by the done() method
+    _valid_done_states = schedule_constants._SCHEDULE_COMPLETE_STATES
+
+    def __init__(
+        self,
+        credentials: auth_credentials.Credentials,
+        project: str,
+        location: str,
+    ):
+        """Retrieves a Schedule resource and instantiates its representation.
+
+        Args:
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this Schedule.
+                Overrides credentials set in aiplatform.init.
+            project (str):
+                Optional. The project that you want to run this Schedule in.
+                If not set, the project set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to create Schedule. If not set,
+                location set in aiplatform.init will be used.
+        """
+        super().__init__(project=project, location=location, credentials=credentials)
+
+    @classmethod
+    def get(
+        cls,
+        schedule_id: str,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> Any:
+        """Get a Vertex AI Schedule for the given resource_name.
+
+        Args:
+            schedule_id (str):
+                Required. Schedule ID used to identify or locate the schedule.
+            project (str):
+                Optional. Project to retrieve the Schedule from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve the Schedule from. If not set,
+                location set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve this Schedule.
+                Overrides credentials set in aiplatform.init.
+
+        Returns:
+            A Vertex AI Schedule.
+        """
+        self = cls._empty_constructor(
+            project=project,
+            location=location,
+            credentials=credentials,
+            resource_name=schedule_id,
+        )
+
+        self._gca_resource = self._get_gca_resource(resource_name=schedule_id)
+
+        return self
+
+    def wait(self) -> None:
+        """Wait for this Schedule to complete."""
+        if self._latest_future is None:
+            self._block_until_complete()
+        else:
+            super().wait()
+
+    @property
+    def state(self) -> Optional[gca_schedule.Schedule.State]:
+        """Current Schedule state.
+
+        Returns:
+            Schedule state.
+        """
+        self._sync_gca_resource()
+        return self._gca_resource.state
+
+    def _block_until_complete(self) -> None:
+        """Helper method to block and check on Schedule until complete."""
+        # Used these numbers so failures surface fast
+        wait = 5  # start at five seconds
+        log_wait = 5
+        max_wait = 60 * 5  # 5 minute wait
+        multiplier = 2  # scale wait by 2 every iteration
+
+        previous_time = time.time()
+        while self.state not in _SCHEDULE_COMPLETE_STATES:
+            current_time = time.time()
+            if current_time - previous_time >= log_wait:
+                _LOGGER.info(
+                    "%s %s current state:\n%s"
+                    % (
+                        self.__class__.__name__,
+                        self._gca_resource.name,
+                        self._gca_resource.state,
+                    )
+                )
+                log_wait = min(log_wait * multiplier, max_wait)
+                previous_time = current_time
+            time.sleep(wait)
+
+        # Error is only populated when the schedule state is STATE_UNSPECIFIED.
+        if self._gca_resource.state in _SCHEDULE_ERROR_STATES:
+            raise RuntimeError("Schedule failed with:\n%s" % self._gca_resource.error)
+        else:
+            _LOGGER.log_action_completed_against_resource("run", "completed", self)
+
+    def _dashboard_uri(self) -> str:
+        """Helper method to compose the dashboard uri where Schedule can be
+        viewed.
+
+        Returns:
+            Dashboard uri where Schedule can be viewed.
+        """
+        fields = self._parse_resource_name(self.resource_name)
+        url = f"https://console.cloud.google.com/vertex-ai/locations/{fields['location']}/pipelines/runs/{fields['schedule']}?project={fields['project']}"
+        return url
diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py
index e0faf2460b..6402a0edc3 100644
--- a/google/cloud/aiplatform/utils/__init__.py
+++ b/google/cloud/aiplatform/utils/__init__.py
@@ -49,6 +49,7 @@
     model_service_client_v1beta1,
     pipeline_service_client_v1beta1,
     prediction_service_client_v1beta1,
+    schedule_service_client_v1beta1,
    tensorboard_service_client_v1beta1,
     vizier_service_client_v1beta1,
     model_garden_service_client_v1beta1,
@@ -89,6 +90,7 @@
     job_service_client_v1beta1.JobServiceClient,
     match_service_client_v1beta1.MatchServiceClient,
     metadata_service_client_v1beta1.MetadataServiceClient,
+    schedule_service_client_v1beta1.ScheduleServiceClient,
     tensorboard_service_client_v1beta1.TensorboardServiceClient,
     vizier_service_client_v1beta1.VizierServiceClient,
     # v1
@@ -592,6 +594,14 @@ class PipelineJobClientWithOverride(ClientWithOverride):
     )


+class ScheduleClientWithOverride(ClientWithOverride):
+    _is_temporary = True
+    _default_version = compat.V1BETA1
+    _version_map = (
+        (compat.V1BETA1, schedule_service_client_v1beta1.ScheduleServiceClient),
+    )
+
+
 class PredictionClientWithOverride(ClientWithOverride):
     _is_temporary = False
     _default_version = compat.DEFAULT_VERSION
@@ -654,6 +664,7 @@ class ModelGardenClientWithOverride(ClientWithOverride):
     PipelineJobClientWithOverride,
     PredictionClientWithOverride,
     MetadataClientWithOverride,
+    ScheduleClientWithOverride,
     TensorboardClientWithOverride,
     VizierClientWithOverride,
     ModelGardenClientWithOverride,
diff --git a/tests/unit/aiplatform/test_pipeline_job_schedules.py b/tests/unit/aiplatform/test_pipeline_job_schedules.py
new file mode 100644
index 0000000000..7a22af103d
--- /dev/null
+++ b/tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -0,0 +1,432 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+from importlib import reload
+import json
+from unittest import mock
+from unittest.mock import patch
+from urllib import request
+import yaml
+
+from google.auth import credentials as auth_credentials
+from google.cloud import storage
+from google.cloud import aiplatform
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform.compat.services import (
+    schedule_service_client_v1beta1 as schedule_service_client,
+)
+from google.cloud.aiplatform.compat.types import (
+    pipeline_job_v1beta1 as gca_pipeline_job,
+    schedule_v1beta1 as gca_schedule,
+)
+from google.cloud.aiplatform.preview.constants import (
+    schedules as schedule_constants,
+)
+from google.cloud.aiplatform import pipeline_jobs
+from google.cloud.aiplatform.preview.pipelinejobschedule import (
+    pipeline_job_schedules,
+)
+from google.cloud.aiplatform.utils import gcs_utils
+import pytest
+
+from google.protobuf import field_mask_pb2 as field_mask
+from google.protobuf import json_format
+
+_TEST_PROJECT = "test-project"
+_TEST_LOCATION = "us-central1"
+_TEST_PIPELINE_JOB_DISPLAY_NAME = "sample-pipeline-job-display-name"
+_TEST_GCS_BUCKET_NAME = "my-bucket"
+_TEST_GCS_OUTPUT_DIRECTORY = f"gs://{_TEST_GCS_BUCKET_NAME}/output_artifacts/"
+_TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
+_TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com"
+
+_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME = "sample-pipeline-job-schedule-display-name"
+_TEST_PIPELINE_JOB_SCHEDULE_ID = "sample-test-schedule-20230417"
+_TEST_PIPELINE_JOB_SCHEDULE_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/schedules/{_TEST_PIPELINE_JOB_SCHEDULE_ID}"
+_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION = "* * * * *"
+_TEST_PIPELINE_JOB_SCHEDULE_CRON_TZ_EXPRESSION = "TZ=America/New_York * * * * *"
+_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT = 1
+_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 2
+
+_TEST_PIPELINE_JOB_SCHEDULE_LIST_READ_MASK = field_mask.FieldMask(
+    paths=schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS
+)
+
+_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
+_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
+_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
+_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
+_TEST_NETWORK = (
+    f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_SCHEDULE_ID}"
+)
+
+_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
+_TEST_PIPELINE_PARAMETER_VALUES = {
+    "string_param": "hello world",
+    "bool_param": True,
+    "double_param": 12.34,
+    "int_param": 5678,
+    "list_int_param": [123, 456, 789],
+    "list_string_param": ["lorem", "ipsum"],
+    "struct_param": {"key1": 12345, "key2": 67890},
+}
+
+_TEST_PIPELINE_INPUT_ARTIFACTS = {
+    "vertex_model": "456",
+}
+
+_TEST_PIPELINE_SPEC_LEGACY_JSON = json.dumps(
+    {
+        "pipelineInfo": {"name": "my-pipeline"},
+        "root": {
+            "dag": {"tasks": {}},
+            "inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}},
+        },
+        "schemaVersion": "2.0.0",
+        "components": {},
+    }
+)
+_TEST_PIPELINE_SPEC_LEGACY_YAML = """\
+pipelineInfo:
+  name: my-pipeline
+root:
+  dag:
+    tasks: {}
+  inputDefinitions:
+    parameters:
+      string_param:
+        type: STRING
+schemaVersion: 2.0.0
+components: {}
+"""
+_TEST_PIPELINE_SPEC_JSON = json.dumps(
+    {
+        "pipelineInfo": {"name": "my-pipeline"},
+        "root": {
+            "dag": {"tasks": {}},
+            "inputDefinitions": {
+                "parameters": {
+                    "string_param": {"parameterType": "STRING"},
+                    "bool_param": {"parameterType": "BOOLEAN"},
+                    "double_param": {"parameterType": "NUMBER_DOUBLE"},
+                    "int_param": {"parameterType": "NUMBER_INTEGER"},
+                    "list_int_param": {"parameterType": "LIST"},
+                    "list_string_param": {"parameterType": "LIST"},
+                    "struct_param": {"parameterType": "STRUCT"},
+                }
+            },
+        },
+        "schemaVersion": "2.1.0",
+        "components": {},
+    }
+)
+_TEST_PIPELINE_SPEC_YAML = """\
+pipelineInfo:
+  name: my-pipeline
+root:
+  dag:
+    tasks: {}
+  inputDefinitions:
+    parameters:
+      string_param:
+        parameterType: STRING
+      bool_param:
+        parameterType: BOOLEAN
+      double_param:
+        parameterType: NUMBER_DOUBLE
+      int_param:
+        parameterType: NUMBER_INTEGER
+      list_int_param:
+        parameterType: LIST
+      list_string_param:
+        parameterType: LIST
+      struct_param:
+        parameterType: STRUCT
+schemaVersion: 2.1.0
+components: {}
+"""
+_TEST_TFX_PIPELINE_SPEC_JSON = json.dumps(
+    {
+        "pipelineInfo": {"name": "my-pipeline"},
+        "root": {
+            "dag": {"tasks": {}},
+            "inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}},
+        },
+        "schemaVersion": "2.0.0",
+        "sdkVersion": "tfx-1.4.0",
+        "components": {},
+    }
+)
+_TEST_TFX_PIPELINE_SPEC_YAML = """\
+pipelineInfo:
+  name: my-pipeline
+root:
+  dag:
+    tasks: {}
+  inputDefinitions:
+    parameters:
+      string_param:
+        type: STRING
+schemaVersion: 2.0.0
+sdkVersion: tfx-1.4.0
+components: {}
+"""
+
+_TEST_PIPELINE_JOB_LEGACY = json.dumps(
+    {"runtimeConfig": {}, "pipelineSpec": json.loads(_TEST_PIPELINE_SPEC_LEGACY_JSON)}
+)
+_TEST_PIPELINE_JOB = json.dumps(
+    {
+        "runtimeConfig": {"parameter_values": _TEST_PIPELINE_PARAMETER_VALUES},
+        "pipelineSpec": json.loads(_TEST_PIPELINE_SPEC_JSON),
+    }
+)
+_TEST_PIPELINE_JOB_TFX = json.dumps(
+    {"runtimeConfig": {}, "pipelineSpec": json.loads(_TEST_TFX_PIPELINE_SPEC_JSON)}
+)
+
+_TEST_CREATE_PIPELINE_JOB_REQUEST = {
+    "parent": _TEST_PARENT,
+    "pipeline_job": {
+        "runtime_config": {"parameter_values": _TEST_PIPELINE_PARAMETER_VALUES},
+        "pipeline_spec": json.loads(_TEST_PIPELINE_SPEC_JSON),
+    },
+}
+
+
+_TEST_SCHEDULE_GET_METHOD_NAME = "get_fake_schedule"
+_TEST_SCHEDULE_LIST_METHOD_NAME = "list_fake_schedules"
+_TEST_SCHEDULE_CANCEL_METHOD_NAME = "cancel_fake_schedule"
+_TEST_SCHEDULE_DELETE_METHOD_NAME = "delete_fake_schedule"
+
+_TEST_PIPELINE_CREATE_TIME = datetime.now()
+
+
+@pytest.fixture
+def mock_schedule_service_create():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "create_schedule"
+    ) as mock_create_schedule:
+        mock_create_schedule.return_value = gca_schedule.Schedule(
+            name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
+            state=gca_schedule.Schedule.State.COMPLETED,
+            create_time=_TEST_PIPELINE_CREATE_TIME,
+            cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
+        )
+        yield mock_create_schedule
+
+
+@pytest.fixture
+def mock_schedule_bucket_exists():
+    def mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
+        output_artifacts_gcs_dir=None,
+        service_account=None,
+        project=None,
+        location=None,
+        credentials=None,
+    ):
+        output_artifacts_gcs_dir = (
+            output_artifacts_gcs_dir
+            or gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
+                project=project,
+                location=location,
+            )
+        )
+        return output_artifacts_gcs_dir
+
+    with mock.patch(
+        "google.cloud.aiplatform.utils.gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist",
+        wraps=mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist,
+    ) as mock_context:
+        yield mock_context
+
+
+def make_schedule(state):
+    return gca_schedule.Schedule(
+        name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
+        state=state,
+        create_time=_TEST_PIPELINE_CREATE_TIME,
+        cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+        max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+        max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+        create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
+    )
+
+
+@pytest.fixture
+def mock_schedule_service_get():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "get_schedule"
+    ) as mock_get_schedule:
+        mock_get_schedule.side_effect = [
+            make_schedule(gca_schedule.Schedule.State.ACTIVE),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+        ]
+
+        yield mock_get_schedule
+
+
+@pytest.fixture
+def mock_schedule_service_get_with_fail():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "get_schedule"
+    ) as mock_get_schedule:
+        mock_get_schedule.side_effect = [
+            make_schedule(gca_schedule.Schedule.State.ACTIVE),
+            make_schedule(gca_schedule.Schedule.State.ACTIVE),
+            make_schedule(gca_schedule.Schedule.State.STATE_UNSPECIFIED),
+        ]
+
+        yield mock_get_schedule
+
+
+@pytest.fixture
+def mock_load_yaml_and_json(job_spec):
+    with patch.object(storage.Blob, "download_as_bytes") as mock_load_yaml_and_json:
+        mock_load_yaml_and_json.return_value = job_spec.encode()
+        yield mock_load_yaml_and_json
+
+
+@pytest.fixture
+def mock_request_urlopen(job_spec):
+    with patch.object(request, "urlopen") as mock_urlopen:
+        mock_read_response = mock.MagicMock()
+        mock_decode_response = mock.MagicMock()
+        mock_decode_response.return_value = job_spec.encode()
+        mock_read_response.return_value.decode = mock_decode_response
+        mock_urlopen.return_value.read = mock_read_response
+        yield mock_urlopen
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class TestPipelineJobSchedule:
+    def setup_method(self):
+        reload(initializer)
+        reload(aiplatform)
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_call_schedule_service_create(
+        self,
+        mock_schedule_service_create,
+        mock_schedule_service_get,
+        mock_schedule_bucket_exists,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        """Creates a PipelineJobSchedule.
+
+        Creates a PipelineJob with a template stored in a GCS bucket.
+        """
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+            enable_caching=True,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        pipeline_job_schedule.create(
+            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+            "inputArtifacts": {"vertex_model": {"artifactId": "456"}},
+        }
+        runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        job_spec = yaml.safe_load(job_spec)
+        pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+
+        # Construct expected request
+        expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+            cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            create_pipeline_job_request={
+                "parent": _TEST_PARENT,
+                "pipeline_job": {
+                    "runtime_config": runtime_config,
+                    "pipeline_spec": {"fields": pipeline_spec},
+                    "service_account": _TEST_SERVICE_ACCOUNT,
+                    "network": _TEST_NETWORK,
+                },
+            },
+        )
+
+        mock_schedule_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            schedule=expected_gapic_pipeline_job_schedule,
+            timeout=None,
+        )
+
+        assert pipeline_job_schedule._gca_resource == make_schedule(
+            gca_schedule.Schedule.State.COMPLETED
+        )
+
+    @pytest.mark.usefixtures("mock_schedule_service_get")
+    def test_get_schedule(self, mock_schedule_service_get):
+        aiplatform.init(project=_TEST_PROJECT)
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule.get(
+            schedule_id=_TEST_PIPELINE_JOB_SCHEDULE_ID
+        )
+
+        mock_schedule_service_get.assert_called_once_with(
+            name=_TEST_PIPELINE_JOB_SCHEDULE_NAME, retry=base._DEFAULT_RETRY
+        )
+        assert isinstance(
+            pipeline_job_schedule, pipeline_job_schedules.PipelineJobSchedule
+        )