diff --git a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
index 85def084a2..5a50dd8200 100644
--- a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
+++ b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
@@ -41,7 +41,6 @@
 )
 from google.protobuf import field_mask_pb2 as field_mask
 
-
 _LOGGER = base.Logger(__name__)
 
 # Pattern for valid names used as a Vertex resource name.
@@ -53,6 +52,8 @@
 # Pattern for any JSON or YAML file over HTTPS.
 _VALID_HTTPS_URL = schedule_constants._VALID_HTTPS_URL
 
+_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
+
 _READ_MASK_FIELDS = schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS
 
 
@@ -385,3 +386,86 @@ def list_jobs(
             location=location,
             credentials=credentials,
         )
+
+    def update(
+        self,
+        display_name: Optional[str] = None,
+        cron_expression: Optional[str] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+        allow_queueing: Optional[bool] = None,
+        max_run_count: Optional[int] = None,
+        max_concurrent_run_count: Optional[int] = None,
+    ) -> None:
+        """Update an existing PipelineJobSchedule.
+
+        Example usage:
+
+        pipeline_job_schedule.update(
+            display_name='updated-display-name',
+            cron_expression='1 2 3 4 5',
+        )
+
+        Args:
+            display_name (str):
+                Optional. The user-defined name of this PipelineJobSchedule.
+            cron_expression (str):
+                Optional. Time specification (cron schedule expression) to launch scheduled runs.
+                To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+                The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
+                For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+            start_time (str):
+                Optional. Timestamp after which the first run can be scheduled.
+                If unspecified, it defaults to the schedule creation timestamp.
+            end_time (str):
+                Optional. Timestamp after which no more runs will be scheduled.
+                If unspecified, then runs will be scheduled indefinitely.
+            allow_queueing (bool):
+                Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
+            max_run_count (int):
+                Optional. Maximum run count of the schedule.
+                If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
+            max_concurrent_run_count (int):
+                Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+
+        Raises:
+            RuntimeError: User tried to call update() before create().
+        """
+        pipeline_job_schedule = self._gca_resource
+        if pipeline_job_schedule.state in _SCHEDULE_ERROR_STATES:
+            raise RuntimeError(
+                "Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
+            )
+
+        updated_fields = []
+        if display_name is not None:
+            updated_fields.append("display_name")
+            setattr(pipeline_job_schedule, "display_name", display_name)
+        if cron_expression is not None:
+            updated_fields.append("cron")
+            setattr(pipeline_job_schedule, "cron", cron_expression)
+        if start_time is not None:
+            updated_fields.append("start_time")
+            setattr(pipeline_job_schedule, "start_time", start_time)
+        if end_time is not None:
+            updated_fields.append("end_time")
+            setattr(pipeline_job_schedule, "end_time", end_time)
+        if allow_queueing is not None:
+            updated_fields.append("allow_queueing")
+            setattr(pipeline_job_schedule, "allow_queueing", allow_queueing)
+        if max_run_count is not None:
+            updated_fields.append("max_run_count")
+            setattr(pipeline_job_schedule, "max_run_count", max_run_count)
+        if max_concurrent_run_count is not None:
+            updated_fields.append("max_concurrent_run_count")
+            setattr(
+                pipeline_job_schedule,
+                "max_concurrent_run_count",
+                max_concurrent_run_count,
+            )
+
+        update_mask = field_mask.FieldMask(paths=updated_fields)
+        self.api_client.update_schedule(
+            schedule=pipeline_job_schedule,
+            update_mask=update_mask,
+        )
diff --git a/tests/unit/aiplatform/test_pipeline_job_schedules.py b/tests/unit/aiplatform/test_pipeline_job_schedules.py
index f1468b4c2a..435a478c0a 100644
--- a/tests/unit/aiplatform/test_pipeline_job_schedules.py
+++ b/tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -69,6 +69,9 @@
 _TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT = 1
 _TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 2
 
+_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION = "1 1 1 1 1"
+_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 5
+
 _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
 _TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
 _TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
@@ -371,6 +374,23 @@ def mock_pipeline_service_list():
         yield mock_list_pipeline_jobs
 
 
+@pytest.fixture
+def mock_schedule_service_update():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "update_schedule"
+    ) as mock_update_schedule:
+        mock_update_schedule.return_value = gca_schedule.Schedule(
+            name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+            state=gca_schedule.Schedule.State.COMPLETED,
+            create_time=_TEST_PIPELINE_CREATE_TIME,
+            cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
+        )
+        yield mock_update_schedule
+
+
 @pytest.fixture
 def mock_load_yaml_and_json(job_spec):
     with patch.object(storage.Blob, "download_as_bytes") as mock_load_yaml_and_json:
@@ -1304,3 +1324,114 @@ def test_resume_pipeline_job_schedule_without_created(
             pipeline_job_schedule.resume()
 
         assert e.match(regexp=r"Schedule resource has not been created")
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_call_schedule_service_update(
+        self,
+        mock_schedule_service_create,
+        mock_schedule_service_update,
+        mock_schedule_service_get,
+        mock_schedule_bucket_exists,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        """Updates a PipelineJobSchedule.
+
+        Updates cron_expression and max_run_count.
+        """
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+            enable_caching=True,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        pipeline_job_schedule.create(
+            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        pipeline_job_schedule.update(
+            cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+        )
+
+        expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
+            name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
+            state=gca_schedule.Schedule.State.COMPLETED,
+            create_time=_TEST_PIPELINE_CREATE_TIME,
+            cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
+        )
+        assert (
+            pipeline_job_schedule._gca_resource == expected_gapic_pipeline_job_schedule
+        )
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_call_schedule_service_update_before_create(
+        self,
+        mock_schedule_service_create,
+        mock_schedule_service_update,
+        mock_schedule_service_get,
+        mock_schedule_bucket_exists,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        """Updates a PipelineJobSchedule.
+
+        Raises error because PipelineJobSchedule should be created before update.
+        """
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+            enable_caching=True,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        with pytest.raises(RuntimeError) as e:
+            pipeline_job_schedule.update(
+                cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+                max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            )
+
+        assert e.match(
+            regexp=r"Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
+        )