Skip to content

Commit

Permalink
feat: Support experiment autologging when using persistent cluster as…
Browse files Browse the repository at this point in the history
… executor

PiperOrigin-RevId: 574306937
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Oct 18, 2023
1 parent a9d7632 commit c19b6c3
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 27 deletions.
30 changes: 27 additions & 3 deletions tests/unit/vertexai/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
ResourcePool,
ResourceRuntimeSpec,
ServiceAccountSpec,
)


Expand All @@ -54,6 +56,7 @@
_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
_TEST_BUCKET_NAME = "gs://test_bucket"
_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir"
_TEST_SERVICE_ACCOUNT = f"{_TEST_PROJECT_NUMBER}-compute@developer.gserviceaccount.com"

_TEST_INPUTS = [
"--arg_0=string_val_0",
Expand Down Expand Up @@ -86,7 +89,9 @@
labels={"trained_by_vertex_ai": "true"},
)

_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource()
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource(
resource_runtime_spec=ResourceRuntimeSpec(service_account_spec=ServiceAccountSpec())
)
resource_pool = ResourcePool()
resource_pool.machine_spec.machine_type = "n1-standard-4"
resource_pool.replica_count = 1
Expand All @@ -95,8 +100,15 @@
_TEST_REQUEST_RUNNING_DEFAULT.resource_pools = [resource_pool]


_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource()
_TEST_PERSISTENT_RESOURCE_RUNNING.state = "RUNNING"
_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource(state="RUNNING")
_TEST_PERSISTENT_RESOURCE_SERVICE_ACCOUNT_RUNNING = PersistentResource(
state="RUNNING",
resource_runtime_spec=ResourceRuntimeSpec(
service_account_spec=ServiceAccountSpec(
enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT
)
),
)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -284,6 +296,18 @@ def persistent_resource_running_mock():
yield persistent_resource_running_mock


@pytest.fixture
def persistent_resource_service_account_running_mock():
with mock.patch.object(
PersistentResourceServiceClient,
"get_persistent_resource",
) as persistent_resource_service_account_running_mock:
persistent_resource_service_account_running_mock.return_value = (
_TEST_PERSISTENT_RESOURCE_SERVICE_ACCOUNT_RUNNING
)
yield persistent_resource_service_account_running_mock


@pytest.fixture
def persistent_resource_exception_mock():
with mock.patch.object(
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/vertexai/test_persistent_resource_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
ResourcePool,
ResourceRuntimeSpec,
ServiceAccountSpec,
)
from vertexai.preview._workflow.executor import (
persistent_resource_util,
Expand Down Expand Up @@ -75,8 +77,14 @@
)
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource(
resource_pools=[resource_pool_0],
resource_runtime_spec=ResourceRuntimeSpec(
service_account_spec=ServiceAccountSpec(enable_custom_service_account=False),
),
)
_TEST_REQUEST_RUNNING_CUSTOM = PersistentResource(
resource_runtime_spec=ResourceRuntimeSpec(
service_account_spec=ServiceAccountSpec(enable_custom_service_account=False),
),
resource_pools=[resource_pool_0, resource_pool_1],
)

Expand Down
140 changes: 134 additions & 6 deletions tests/unit/vertexai/test_remote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

# vertexai constants
_TEST_PROJECT = "test-project"
_TEST_PROJECT_NUMBER = 123
_TEST_PROJECT_NUMBER = 12345678
_TEST_LOCATION = "us-central1"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_BUCKET_NAME = "gs://test-bucket"
Expand All @@ -88,6 +88,7 @@
_TEST_REMOTE_JOB_BASE_PATH = os.path.join(_TEST_BUCKET_NAME, _TEST_REMOTE_JOB_NAME)
_TEST_EXPERIMENT = "test-experiment"
_TEST_EXPERIMENT_RUN = "test-experiment-run"
_TEST_SERVICE_ACCOUNT = f"{_TEST_PROJECT_NUMBER}-compute@developer.gserviceaccount.com"

# dataset constants
dataset = load_iris()
Expand Down Expand Up @@ -269,6 +270,20 @@
],
)

_TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT = configs.PersistentResourceConfig(
name=_TEST_PERSISTENT_RESOURCE_ID,
resource_pools=[
remote_specs.ResourcePool(
replica_count=_TEST_REPLICA_COUNT,
),
remote_specs.ResourcePool(
machine_type="n1-standard-8",
replica_count=2,
),
],
service_account=_TEST_SERVICE_ACCOUNT,
)

_TEST_PERSISTENT_RESOURCE_CONFIG_DISABLE = configs.PersistentResourceConfig(
name=_TEST_PERSISTENT_RESOURCE_ID,
resource_pools=[
Expand Down Expand Up @@ -1583,7 +1598,7 @@ def test_remote_training_keras_distributed_no_cuda_no_worker_pool_specs(
@pytest.mark.xfail(
sys.version_info.minor >= 8,
raises=ValueError,
reason="Flaky in python 3.8, 3.10, 3.11",
reason="Flaky in python >=3.8",
)
@pytest.mark.usefixtures(
"list_default_tensorboard_mock",
Expand Down Expand Up @@ -1667,7 +1682,7 @@ def test_remote_training_sklearn_with_experiment(
@pytest.mark.xfail(
sys.version_info.minor >= 8,
raises=ValueError,
reason="Flaky in python 3.8, 3.10, 3.11",
reason="Flaky in python >=3.8",
)
@pytest.mark.usefixtures(
"list_default_tensorboard_mock",
Expand Down Expand Up @@ -1856,6 +1871,27 @@ def test_remote_training_sklearn_with_persistent_cluster(
model.score(_X_TEST, _Y_TEST)

@pytest.mark.usefixtures(
"mock_timestamped_unique_name",
"mock_get_custom_job",
"mock_autolog_disabled",
"persistent_resource_running_mock",
)
def test_initialize_existing_persistent_resource_service_account_mismatch(self):
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_BUCKET_NAME,
)
with pytest.raises(ValueError) as e:
vertexai.preview.init(
cluster=_TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT
)
e.match(
regexp=r"Expect the existing cluster was created with the service account "
)

@pytest.mark.usefixtures(
"mock_get_project_number",
"list_default_tensorboard_mock",
"mock_get_experiment_run",
"mock_get_metadata_store",
Expand All @@ -1865,7 +1901,7 @@ def test_remote_training_sklearn_with_persistent_cluster(
"mock_autolog_enabled",
"persistent_resource_running_mock",
)
def test_remote_training_sklearn_with_persistent_cluster_and_experiment_error(
def test_remote_training_sklearn_with_persistent_cluster_no_service_account_and_experiment_error(
self,
):
vertexai.init(
Expand All @@ -1884,9 +1920,101 @@ def test_remote_training_sklearn_with_persistent_cluster_and_experiment_error(
with pytest.raises(ValueError) as e:
model.fit.vertex.remote_config.service_account = "GCE"
model.fit(_X_TRAIN, _Y_TRAIN)
e.match(
regexp=r"Persistent cluster currently does not support custom service account."
e.match(regexp=r"The service account for autologging")

# TODO(b/300116902) Remove this once we find better solution.
@pytest.mark.xfail(
sys.version_info.minor >= 8,
raises=ValueError,
reason="Flaky in python >=3.8",
)
@pytest.mark.usefixtures(
"mock_get_project_number",
"list_default_tensorboard_mock",
"mock_get_experiment_run",
"mock_get_metadata_store",
"get_artifact_not_found_mock",
"update_context_mock",
"aiplatform_autolog_mock",
"mock_autolog_enabled",
"persistent_resource_service_account_running_mock",
"mock_timestamped_unique_name",
"mock_get_custom_job",
)
def test_remote_training_sklearn_with_persistent_cluster_and_experiment_autologging(
self,
mock_any_serializer_sklearn,
mock_create_custom_job,
):
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_BUCKET_NAME,
experiment=_TEST_EXPERIMENT,
)
vertexai.preview.init(
remote=True,
autolog=True,
cluster=_TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT,
)

vertexai.preview.start_run(_TEST_EXPERIMENT_RUN, resume=True)

LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression)
model = LogisticRegression()

model.fit.vertex.remote_config.service_account = _TEST_SERVICE_ACCOUNT

model.fit(_X_TRAIN, _Y_TRAIN)

# check that model is serialized correctly
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
to_serialize=model,
gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"),
)

# check that args are serialized correctly
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
to_serialize=_X_TRAIN,
gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"),
)
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
to_serialize=_Y_TRAIN,
gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"),
)

# ckeck that CustomJob is created correctly
expected_custom_job = _get_custom_job_proto(
service_account=_TEST_SERVICE_ACCOUNT,
experiment=_TEST_EXPERIMENT,
experiment_run=_TEST_EXPERIMENT_RUN,
autolog_enabled=True,
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
)
mock_create_custom_job.assert_called_once_with(
parent=_TEST_PARENT,
custom_job=expected_custom_job,
timeout=None,
)

# check that trained model is deserialized correctly
mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls(
[
mock.call(
os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator")
),
mock.call(
os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data")
),
]
)

# change to `vertexai.preview.init(remote=False)` to use local prediction
vertexai.preview.init(remote=False)

# check that local model is updated in place
# `model.score` raises NotFittedError if the model is not updated
model.score(_X_TEST, _Y_TEST)

@pytest.mark.usefixtures(
"mock_timestamped_unique_name",
Expand Down
42 changes: 41 additions & 1 deletion vertexai/preview/_workflow/executor/persistent_resource_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
ResourcePool,
ResourceRuntimeSpec,
ServiceAccountSpec,
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
GetPersistentResourceRequest,
Expand Down Expand Up @@ -61,18 +63,28 @@ def _create_persistent_resource_client(location: Optional[str] = "us-central1"):
)


def check_persistent_resource(cluster_resource_name: str) -> bool:
def cluster_resource_name(project: str, location: str, name: str) -> str:
"""Helper method to get persistent resource name."""
client = _create_persistent_resource_client(location)
return client.persistent_resource_path(project, location, name)


def check_persistent_resource(
cluster_resource_name: str, service_account: Optional[str] = None
) -> bool:
"""Helper method to check if a persistent resource exists or not.
Args:
cluster_resource_name: Persistent Resource name. Has the form:
``projects/my-project/locations/my-region/persistentResource/cluster-name``.
service_account: Service account.
Returns:
True if a Persistent Resource exists.
Raises:
ValueError: if existing cluster is not RUNNING.
ValueError: if service account is specified but mismatch with existing cluster.
"""
# Parse resource name to get the location.
locataion = cluster_resource_name.split("/")[3]
Expand All @@ -91,6 +103,24 @@ def check_persistent_resource(cluster_resource_name: str) -> bool:
cluster_resource_name,
"` isn't running, please specify a different cluster_name.",
)
# Check if service account of this existing persistent resource matches initialized one.
existing_cluster_service_account = (
response.resource_runtime_spec.service_account_spec.service_account
if response.resource_runtime_spec.service_account_spec
else None
)

if (
service_account is not None
and existing_cluster_service_account != service_account
):
raise ValueError(
"Expect the existing cluster was created with the service account `",
service_account,
"`, but got `",
existing_cluster_service_account,
"` , please ensure service account is consistent with the initialization.",
)
return True


Expand Down Expand Up @@ -185,6 +215,7 @@ def _get_persistent_resource(cluster_resource_name: str):
def create_persistent_resource(
cluster_resource_name: str,
resource_pools: Optional[List[remote_specs.ResourcePool]] = None,
service_account: Optional[str] = None,
):
"""Create a persistent resource."""
locataion = cluster_resource_name.split("/")[3]
Expand All @@ -209,6 +240,15 @@ def create_persistent_resource(

persistent_resource = PersistentResource(resource_pools=pools)

enable_custom_service_account = True if service_account is not None else False

resource_runtime_spec = ResourceRuntimeSpec(
service_account_spec=ServiceAccountSpec(
enable_custom_service_account=enable_custom_service_account,
service_account=service_account,
),
)
persistent_resource.resource_runtime_spec = resource_runtime_spec
request = persistent_resource_service.CreatePersistentResourceRequest(
parent=parent,
persistent_resource=persistent_resource,
Expand Down
Loading

0 comments on commit c19b6c3

Please sign in to comment.