chore: Update implementation for experiment proto change in CustomJob
PiperOrigin-RevId: 635697103
jaycee-li authored and copybara-github committed May 21, 2024
1 parent 352eccf commit 555ead7
Showing 2 changed files with 128 additions and 119 deletions.
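For orientation: this change moves the experiment association out of container environment variables (AIP_EXPERIMENT_NAME / AIP_EXPERIMENT_RUN_NAME) and onto the CustomJob proto itself, via the new job_spec.experiment and job_spec.experiment_run fields. A minimal usage sketch follows; the project, bucket, image, experiment, and run names are placeholders, and it assumes the experiment and run already exist in Vertex ML Metadata.

```python
# Sketch of submitting a CustomJob with experiment tracking after this change.
# All names/values below are placeholders, not taken from the commit.
from google.cloud import aiplatform

aiplatform.init(
    project="my-project",
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",
)

job = aiplatform.CustomJob(
    display_name="my-custom-job",
    worker_pool_specs=[
        {
            "machine_spec": {"machine_type": "n1-standard-4"},
            "replica_count": 1,
            "container_spec": {"image_uri": "gcr.io/my-project/my-training-image"},
        }
    ],
)

# The SDK resolves these names to metadata context resource names and stores
# them on the job proto (job_spec.experiment / job_spec.experiment_run)
# instead of injecting AIP_EXPERIMENT_NAME / AIP_EXPERIMENT_RUN_NAME env vars.
job.submit(experiment="my-experiment", experiment_run="run-1")

print(job.job_spec.experiment)      # .../metadataStores/default/contexts/my-experiment
print(job.job_spec.experiment_run)  # .../metadataStores/default/contexts/my-experiment-run-1
```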
154 changes: 51 additions & 103 deletions google/cloud/aiplatform/jobs.py
@@ -15,16 +15,16 @@
# limitations under the License.
#

from typing import Iterable, Optional, Union, Sequence, Dict, List
from typing import Iterable, Optional, Union, Sequence, Dict, List, Tuple

import abc
import copy
import datetime
import time
import tempfile
import uuid

from google.auth import credentials as auth_credentials
from google.api_core import exceptions as api_exceptions
from google.protobuf import duration_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.rpc import status_pb2
@@ -35,6 +35,7 @@
batch_prediction_job as gca_bp_job_compat,
completion_stats as gca_completion_stats,
custom_job as gca_custom_job_compat,
execution as gca_execution_compat,
explanation as gca_explanation_compat,
encryption_spec as gca_encryption_spec_compat,
io as gca_io_compat,
@@ -61,7 +62,6 @@
batch_prediction_job as batch_prediction_job_v1,
)
from google.cloud.aiplatform_v1.types import custom_job as custom_job_v1
from google.cloud.aiplatform_v1.types import execution as execution_v1

_LOGGER = base.Logger(__name__)

@@ -1583,8 +1583,6 @@ def _empty_constructor(
self._logged_web_access_uris = set()

if isinstance(self, CustomJob):
self._experiment = None
self._experiment_run = None
self._enable_autolog = False

return self
@@ -1633,13 +1631,22 @@ def _block_until_complete(self):

self._log_job_state()

if isinstance(self, CustomJob) and self._experiment_run:
# sync resource before end run
self._experiment_run = aiplatform.ExperimentRun.get(
self._experiment_run.name,
experiment=self._experiment,
)
self._experiment_run.end_run()
if isinstance(self, CustomJob):
# End the experiment run associated with the custom job, if exists.
experiment_run = self._gca_resource.job_spec.experiment_run
if experiment_run:
try:
# sync resource before end run
experiment_run_context = aiplatform.Context(experiment_run)
experiment_run_context.update(
metadata={
metadata_constants._STATE_KEY: gca_execution_compat.Execution.State.COMPLETE.name
}
)
except (ValueError, api_exceptions.GoogleAPIError) as e:
_LOGGER.warning(
f"Failed to end experiment run {experiment_run} due to: {e}"
)

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
@@ -1852,8 +1859,6 @@ def __init__(
),
)

self._experiment = None
self._experiment_run = None
self._enable_autolog = False

@property
@@ -2510,79 +2515,10 @@ def submit(
if persistent_resource_id:
self._gca_resource.job_spec.persistent_resource_id = persistent_resource_id

# TODO(b/275105711) Update implementation after experiment/run in the proto
if experiment:
# short-term solution to set experiment/experimentRun in SDK
if isinstance(experiment, aiplatform.Experiment):
self._experiment = experiment
# convert the Experiment instance to string to be passed to env
experiment = experiment.name
else:
self._experiment = aiplatform.Experiment.get(experiment_name=experiment)
if not self._experiment:
raise ValueError(
f"Experiment '{experiment}' doesn't exist. "
"Please call aiplatform.init(experiment='my-exp') to create an experiment."
)
elif (
not self._experiment.backing_tensorboard_resource_name
and self._enable_autolog
):
raise ValueError(
f"Experiment '{experiment}' doesn't have a backing tensorboard resource, "
"which is required by the experiment autologging feature. "
"Please call Experiment.assign_backing_tensorboard('my-tb-resource-name')."
)

# if run name is not specified, auto-generate one
if not experiment_run:
experiment_run = (
# TODO(b/223262536)Once display_name is optional this run name
# might be invalid as well.
f"{self._gca_resource.display_name}-{uuid.uuid4().hex[0:5]}"
)

# get or create the experiment run for the job
if isinstance(experiment_run, aiplatform.ExperimentRun):
self._experiment_run = experiment_run
# convert the ExperimentRun instance to string to be passed to env
experiment_run = experiment_run.name
else:
self._experiment_run = aiplatform.ExperimentRun.get(
run_name=experiment_run,
experiment=self._experiment,
)
if not self._experiment_run:
self._experiment_run = aiplatform.ExperimentRun.create(
run_name=experiment_run,
experiment=self._experiment,
)
self._experiment_run.update_state(execution_v1.Execution.State.RUNNING)

worker_pool_specs = self._gca_resource.job_spec.worker_pool_specs
for spec in worker_pool_specs:
if not spec:
continue

if "python_package_spec" in spec:
container_spec = spec.python_package_spec
else:
container_spec = spec.container_spec

experiment_env = [
{
"name": metadata_constants.ENV_EXPERIMENT_KEY,
"value": experiment,
},
{
"name": metadata_constants.ENV_EXPERIMENT_RUN_KEY,
"value": experiment_run,
},
]
if "env" in container_spec:
container_spec.env.extend(experiment_env)
else:
container_spec.env = experiment_env
(
self._gca_resource.job_spec.experiment,
self._gca_resource.job_spec.experiment_run,
) = self._get_experiment_and_run_resource_name(experiment, experiment_run)

_LOGGER.log_create_with_lro(self.__class__)

@@ -2606,26 +2542,38 @@ def submit(
)
)

if experiment:
custom_job = {
metadata_constants._CUSTOM_JOB_RESOURCE_NAME: self.resource_name,
metadata_constants._CUSTOM_JOB_CONSOLE_URI: self._dashboard_uri(),
}

run_context = self._experiment_run._metadata_node
custom_jobs = run_context._gca_resource.metadata.get(
metadata_constants._CUSTOM_JOB_KEY
)
if custom_jobs:
custom_jobs.append(custom_job)
else:
custom_jobs = [custom_job]
run_context.update({metadata_constants._CUSTOM_JOB_KEY: custom_jobs})

@property
def job_spec(self):
return self._gca_resource.job_spec

@staticmethod
def _get_experiment_and_run_resource_name(
experiment: Optional[Union["aiplatform.Experiment", str]] = None,
experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None,
) -> Tuple[str, str]:
"""Helper method to get the experiment and run resource name for the custom job."""
if not experiment:
return None, None

experiment_resource = (
aiplatform.Experiment(experiment)
if isinstance(experiment, str)
else experiment
)

if not experiment_run:
return experiment_resource.resource_name, None

experiment_run_resource = (
aiplatform.ExperimentRun(experiment_run, experiment_resource)
if isinstance(experiment_run, str)
else experiment_run
)
return (
experiment_resource.resource_name,
experiment_run_resource.resource_name,
)


class HyperparameterTuningJob(_RunnableJob, base.PreviewMixin):
"""Vertex AI Hyperparameter Tuning Job."""
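With this change, job completion no longer calls ExperimentRun.end_run(); _block_until_complete reads job_spec.experiment_run from the job proto and marks the run's Context COMPLETE with a direct metadata update. A rough standalone equivalent of that final step, with a placeholder run context resource name:

```python
# Sketch of the completion step the job now performs itself: mark an
# experiment-run Context as COMPLETE via a metadata update. The resource
# name below is a placeholder.
from google.cloud import aiplatform
from google.cloud.aiplatform.compat.types import execution as gca_execution_compat
from google.cloud.aiplatform.metadata import constants as metadata_constants

run_resource_name = (
    "projects/123/locations/us-central1/metadataStores/default/contexts/"
    "test-experiment-run-1"
)

aiplatform.init(project="my-project", location="us-central1")

# Look up the run's Context and flip its state metadata, mirroring the
# try/except block added to _block_until_complete above.
run_context = aiplatform.Context(run_resource_name)
run_context.update(
    metadata={
        metadata_constants._STATE_KEY: (
            gca_execution_compat.Execution.State.COMPLETE.name
        )
    }
)
```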
93 changes: 77 additions & 16 deletions tests/unit/aiplatform/test_custom_job.py
@@ -83,10 +83,6 @@
"image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
"command": [],
"args": _TEST_RUN_ARGS,
"env": [
{"name": "AIP_EXPERIMENT_NAME", "value": _TEST_EXPERIMENT},
{"name": "AIP_EXPERIMENT_RUN_NAME", "value": _TEST_EXPERIMENT_RUN},
],
},
}
]
@@ -160,6 +156,7 @@
_TEST_EXPERIMENT_DESCRIPTION = "test-experiment-description"
_TEST_RUN = "run-1"
_TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}"
_TEST_EXPERIMENT_CONTEXT_NAME = f"{_TEST_PARENT_METADATA}/contexts/{_TEST_EXPERIMENT}"
_TEST_EXPERIMENT_RUN_CONTEXT_NAME = (
f"{_TEST_PARENT_METADATA}/contexts/{_TEST_EXECUTION_ID}"
)
@@ -203,6 +200,8 @@ def _get_custom_job_proto_with_experiments(state=None, name=None, error=None):
custom_job_proto.name = name
custom_job_proto.state = state
custom_job_proto.error = error
custom_job_proto.job_spec.experiment = _TEST_EXPERIMENT_CONTEXT_NAME
custom_job_proto.job_spec.experiment_run = _TEST_EXPERIMENT_RUN_CONTEXT_NAME
return custom_job_proto


@@ -255,6 +254,28 @@ def get_custom_job_mock():
yield get_custom_job_mock


@pytest.fixture
def get_custom_job_with_experiments_mock():
with patch.object(
job_service_client.JobServiceClient, "get_custom_job"
) as get_custom_job_mock:
get_custom_job_mock.side_effect = [
_get_custom_job_proto(
name=_TEST_CUSTOM_JOB_NAME,
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
),
_get_custom_job_proto(
name=_TEST_CUSTOM_JOB_NAME,
state=gca_job_state_compat.JobState.JOB_STATE_RUNNING,
),
_get_custom_job_proto_with_experiments(
name=_TEST_CUSTOM_JOB_NAME,
state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED,
),
]
yield get_custom_job_mock


@pytest.fixture
def get_custom_tpu_v5e_job_mock():
with patch.object(
@@ -455,6 +476,19 @@ def get_experiment_run_run_mock():
yield get_context_mock


@pytest.fixture
def get_experiment_run_not_found_mock():
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
get_context_mock.side_effect = [
_EXPERIMENT_MOCK,
_EXPERIMENT_RUN_MOCK,
_EXPERIMENT_MOCK,
exceptions.NotFound(""),
]

yield get_context_mock


@pytest.fixture
def update_context_mock():
with patch.object(MetadataServiceClient, "update_context") as update_context_mock:
@@ -598,7 +632,7 @@ def test_submit_custom_job_with_experiments(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
create_request_timeout=None,
experiment=_TEST_EXPERIMENT,
experiment_run=_TEST_EXPERIMENT_RUN,
experiment_run=_TEST_RUN,
disable_retries=_TEST_DISABLE_RETRIES,
)

Expand All @@ -616,17 +650,6 @@ def test_submit_custom_job_with_experiments(
timeout=None,
)

expected_run_context = copy.deepcopy(_EXPERIMENT_RUN_MOCK)
expected_run_context.metadata[constants._CUSTOM_JOB_KEY] = [
{
constants._CUSTOM_JOB_RESOURCE_NAME: _TEST_CUSTOM_JOB_NAME,
constants._CUSTOM_JOB_CONSOLE_URI: job._dashboard_uri(),
}
]
update_context_mock.assert_called_with(
context=expected_run_context,
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
@@ -714,6 +737,44 @@ def test_create_custom_job_with_timeout_not_explicitly_set(
timeout=None,
)

@pytest.mark.usefixtures(
"create_custom_job_mock",
"get_custom_job_with_experiments_mock",
"get_experiment_run_not_found_mock",
"get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_experiment_run_warning(self, caplog):

aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_STAGING_BUCKET,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME,
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
labels=_TEST_LABELS,
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
create_request_timeout=None,
experiment=_TEST_EXPERIMENT,
experiment_run=_TEST_RUN,
disable_retries=_TEST_DISABLE_RETRIES,
)

assert (
f"Failed to end experiment run {_TEST_EXPERIMENT_RUN_CONTEXT_NAME} due to:"
in caplog.text
)

@pytest.mark.parametrize("sync", [True, False])
def test_run_custom_job_with_fail_raises(
self, create_custom_job_mock, get_custom_job_mock_with_fail, sync
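Outside the unit tests, the new fields can also be read back from a retrieved job; a short sketch with a placeholder job resource name:

```python
# Sketch: inspect the experiment association recorded on an existing CustomJob.
# The job resource name below is a placeholder.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

job = aiplatform.CustomJob.get(
    "projects/my-project/locations/us-central1/customJobs/1234567890"
)

# After this change, both fields hold Vertex ML Metadata context resource names.
print(job.job_spec.experiment)
print(job.job_spec.experiment_run)
```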
