Skip to content

Commit

Permalink
feat: Add Persistent Resource Id parameter to Custom Training Job run and submit methods.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 622311949
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Apr 5, 2024
1 parent f5be0b5 commit f428006
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 0 deletions.
83 changes: 83 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -1489,6 +1489,7 @@ def _prepare_training_task_inputs_and_output_dir(
enable_dashboard_access: bool = False,
tensorboard: Optional[str] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Tuple[Dict, str]:
"""Prepares training task inputs and output directory for custom job.
Expand Down Expand Up @@ -1539,6 +1540,14 @@ def _prepare_training_task_inputs_and_output_dir(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
Training task inputs and Output directory for custom job.
"""
Expand All @@ -1565,6 +1574,8 @@ def _prepare_training_task_inputs_and_output_dir(
training_task_inputs["enable_web_access"] = enable_web_access
if enable_dashboard_access:
training_task_inputs["enable_dashboard_access"] = enable_dashboard_access
if persistent_resource_id:
training_task_inputs["persistent_resource_id"] = persistent_resource_id

if timeout or restart_job_on_worker_restart or disable_retries:
timeout = f"{timeout}s" if timeout else None
Expand Down Expand Up @@ -2962,6 +2973,7 @@ def run(
sync=True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
Expand Down Expand Up @@ -3249,6 +3261,13 @@ def run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -3311,6 +3330,7 @@ def run(
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

def submit(
Expand Down Expand Up @@ -3362,6 +3382,7 @@ def submit(
sync=True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Submits the custom training job without blocking until completion.
Expand Down Expand Up @@ -3649,6 +3670,13 @@ def submit(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -3711,6 +3739,7 @@ def submit(
create_request_timeout=create_request_timeout,
block=False,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

@base.optional_sync(construct_object_on_arg="managed_model")
Expand Down Expand Up @@ -3757,6 +3786,7 @@ def _run(
create_request_timeout: Optional[float] = None,
block: Optional[bool] = True,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Expand Down Expand Up @@ -3946,6 +3976,13 @@ def _run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -3999,6 +4036,7 @@ def _run(
enable_dashboard_access=enable_dashboard_access,
tensorboard=tensorboard,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

model = self._run_job(
Expand Down Expand Up @@ -4321,6 +4359,7 @@ def run(
sync=True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
Expand Down Expand Up @@ -4601,6 +4640,13 @@ def run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -4662,6 +4708,7 @@ def run(
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

def submit(
Expand Down Expand Up @@ -4713,6 +4760,7 @@ def submit(
sync=True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Submits the custom training job without blocking until completion.
Expand Down Expand Up @@ -4993,6 +5041,13 @@ def submit(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -5054,6 +5109,7 @@ def submit(
create_request_timeout=create_request_timeout,
block=False,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

@base.optional_sync(construct_object_on_arg="managed_model")
Expand Down Expand Up @@ -5099,6 +5155,7 @@ def _run(
create_request_timeout: Optional[float] = None,
block: Optional[bool] = True,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Args:
Expand Down Expand Up @@ -5284,6 +5341,13 @@ def _run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -5331,6 +5395,7 @@ def _run(
enable_dashboard_access=enable_dashboard_access,
tensorboard=tensorboard,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

model = self._run_job(
Expand Down Expand Up @@ -7249,6 +7314,7 @@ def run(
sync=True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
Expand Down Expand Up @@ -7530,6 +7596,13 @@ def run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -7586,6 +7659,7 @@ def run(
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

@base.optional_sync(construct_object_on_arg="managed_model")
Expand Down Expand Up @@ -7630,6 +7704,7 @@ def _run(
sync=True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Expand Down Expand Up @@ -7800,6 +7875,13 @@ def _run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
persistent_resource_id (str):
Optional. The ID of the PersistentResource in the same Project
and Location. If this is specified, the job will be run on
existing machines held by the PersistentResource instead of
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -7847,6 +7929,7 @@ def _run(
enable_dashboard_access=enable_dashboard_access,
tensorboard=tensorboard,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
)

model = self._run_job(
Expand Down

0 comments on commit f428006

Please sign in to comment.