Skip to content

Commit

Permalink
feat: Add Persistent Resource ID parameter to Custom Job from_local_s…
Browse files Browse the repository at this point in the history
…cript, run, and submit methods.

PiperOrigin-RevId: 622310810
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Apr 5, 2024
1 parent 8c6ddf5 commit f5be0b5
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 2 deletions.
38 changes: 38 additions & 0 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,7 @@ def from_local_script(
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
persistent_resource_id: Optional[str] = None,
) -> "CustomJob":
"""Configures a custom job from a local script.
Expand Down Expand Up @@ -2026,6 +2027,13 @@ def from_local_script(
staging_bucket (str):
Optional. Bucket for produced custom job artifacts. Overrides
staging_bucket set in aiplatform.init.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Raises:
RuntimeError: If staging bucket was not set using aiplatform.init
Expand Down Expand Up @@ -2171,6 +2179,7 @@ def from_local_script(
labels=labels,
encryption_spec_key_name=encryption_spec_key_name,
staging_bucket=staging_bucket,
persistent_resource_id=persistent_resource_id,
)

if enable_autolog:
Expand All @@ -2191,6 +2200,7 @@ def run(
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> None:
"""Run this configured CustomJob.
Expand Down Expand Up @@ -2252,6 +2262,13 @@ def run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account
Expand All @@ -2268,6 +2285,7 @@ def run(
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

@base.optional_sync()
Expand All @@ -2284,6 +2302,7 @@ def _run(
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
Expand Down Expand Up @@ -2343,6 +2362,13 @@ def _run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
"""
self.submit(
service_account=service_account,
Expand All @@ -2355,6 +2381,7 @@ def _run(
tensorboard=tensorboard,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

self._block_until_complete()
Expand All @@ -2372,6 +2399,7 @@ def submit(
tensorboard: Optional[str] = None,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> None:
"""Submit the configured CustomJob.
Expand Down Expand Up @@ -2428,6 +2456,13 @@ def submit(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Raises:
ValueError:
Expand Down Expand Up @@ -2464,6 +2499,9 @@ def submit(
if tensorboard:
self._gca_resource.job_spec.tensorboard = tensorboard

if persistent_resource_id:
self._gca_resource.job_spec.persistent_resource_id = persistent_resource_id

# TODO(b/275105711) Update implementation after experiment/run in the proto
if experiment:
# short-term solution to set experiment/experimentRun in SDK
Expand Down
104 changes: 102 additions & 2 deletions tests/unit/aiplatform/test_custom_job_persistent_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform.compat.services import job_service_client_v1
from google.cloud.aiplatform.compat.types import (
custom_job as gca_custom_job_compat,
)
from google.cloud.aiplatform.compat.types import custom_job_v1
from google.cloud.aiplatform.compat.types import encryption_spec_v1
from google.cloud.aiplatform.compat.types import io_v1
from google.cloud.aiplatform.compat.types import job_state_v1 as gca_job_state_compat
from google.cloud.aiplatform.compat.types import (
job_state_v1 as gca_job_state_compat,
)
import constants as test_constants
import pytest

Expand Down Expand Up @@ -71,6 +76,11 @@

_TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS

# Expected PythonPackageSpec that CustomJob.from_local_script builds after
# packaging the local training script — presumably the packaging mock uploads
# to _TEST_OUTPUT_PYTHON_PACKAGE_PATH (defined in shared test constants;
# verify against the fixture).
_TEST_PYTHON_PACKAGE_SPEC = gca_custom_job_compat.PythonPackageSpec(
    executor_image_uri=_TEST_PREBUILT_CONTAINER_IMAGE,
    package_uris=[test_constants.TrainingJobConstants._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
    python_module=test_constants.TrainingJobConstants._TEST_MODULE_NAME,
)

# Persistent Resource
_TEST_PERSISTENT_RESOURCE_ID = "test-persistent-resource-1"
Expand Down Expand Up @@ -212,7 +222,6 @@ def test_submit_custom_job_with_persistent_resource(
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
labels=_TEST_LABELS,
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
)

job.submit(
Expand All @@ -222,6 +231,7 @@ def test_submit_custom_job_with_persistent_resource(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
create_request_timeout=None,
disable_retries=_TEST_DISABLE_RETRIES,
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
)

job.wait_for_resource_creation()
Expand All @@ -243,3 +253,93 @@ def test_submit_custom_job_with_persistent_resource(
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
)
assert job.network == _TEST_NETWORK

@pytest.mark.parametrize("sync", [True, False])
def test_run_custom_job_with_persistent_resource(
    self, create_custom_job_mock, get_custom_job_mock, sync
):
    """run() forwards persistent_resource_id into the create request.

    The ID is supplied at run() time (not in the CustomJob constructor) and
    the test asserts that the CreateCustomJob call to the mocked service
    receives the expected proto, in both sync and async modes.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = jobs.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    # persistent_resource_id passed here exercises the new run() parameter.
    job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        sync=sync,
        persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
    )

    # Blocks until the (mocked) resource exists, covering the async path too.
    job.wait_for_resource_creation()

    assert job.resource_name == _TEST_CUSTOM_JOB_NAME

    job.wait()

    # Helper defined elsewhere in this module; presumably builds the proto
    # including the persistent resource ID — confirm against its definition.
    expected_custom_job = _get_custom_job_proto()

    # Exact-match on the request proto sent to the service mock.
    create_custom_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        custom_job=expected_custom_job,
        timeout=None,
    )

    assert job.job_spec == expected_custom_job.job_spec
    assert (
        job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
    )
    assert job.network == _TEST_NETWORK

@pytest.mark.usefixtures("mock_python_package_to_gcs")
@pytest.mark.parametrize("sync", [True, False])
def test_from_local_script_custom_job_with_persistent_resource(
    self, create_custom_job_mock, get_custom_job_mock, sync
):
    """from_local_script() forwards persistent_resource_id to the job spec.

    Builds a CustomJob from a local script (script packaging is mocked via
    mock_python_package_to_gcs), then checks that the resulting job spec
    carries both the expected python package spec and the persistent
    resource ID, in both sync and async modes.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    # persistent_resource_id passed here exercises the new
    # from_local_script() parameter.
    job = jobs.CustomJob.from_local_script(
        display_name=_TEST_DISPLAY_NAME,
        script_path=test_constants.TrainingJobConstants._TEST_LOCAL_SCRIPT_FILE_NAME,
        container_uri=_TEST_PREBUILT_CONTAINER_IMAGE,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
        persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
    )

    # The packaged script should land in worker pool 0 as a python package.
    assert (
        job.job_spec.worker_pool_specs[0].python_package_spec
        == _TEST_PYTHON_PACKAGE_SPEC
    )

    job.run(sync=sync)

    job.wait_for_resource_creation()

    assert job.resource_name == _TEST_CUSTOM_JOB_NAME

    job.wait()

    # Core assertion for this change: the ID set at construction time is
    # present on the job spec after the job runs.
    assert job.job_spec.persistent_resource_id == _TEST_PERSISTENT_RESOURCE_ID
    assert (
        job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
    )

0 comments on commit f5be0b5

Please sign in to comment.