Skip to content

Commit

Permalink
feat: Add an arg to turn off Ray metrics collection during cluster cr…
Browse files Browse the repository at this point in the history
…eation

PiperOrigin-RevId: 617612703
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Mar 20, 2024
1 parent e51c977 commit e33d11f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 83 deletions.
10 changes: 8 additions & 2 deletions google/cloud/aiplatform/preview/vertex_ray/cluster_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
RaySpec,
RayMetricSpec,
ResourcePool,
ResourceRuntimeSpec,
)
Expand All @@ -49,6 +50,7 @@ def create_ray_cluster(
cluster_name: Optional[str] = None,
worker_node_types: Optional[List[resources.Resources]] = None,
custom_images: Optional[resources.NodeImages] = None,
enable_metrics_collection: Optional[bool] = True,
labels: Optional[Dict[str, str]] = None,
) -> str:
"""Create a ray cluster on the Vertex AI.
Expand Down Expand Up @@ -107,6 +109,7 @@ def create_ray_cluster(
has a specific custom image, use `Resources.custom_image` for
head/worker_node_type(s). Note that configuring `Resources.custom_image`
will override `custom_images` here. Allowlist only.
enable_metrics_collection: Enable Ray metrics collection for visualization.
labels:
The labels with user-defined metadata to organize Ray cluster.
Expand Down Expand Up @@ -244,8 +247,11 @@ def create_ray_cluster(
i += 1

resource_pools = [resource_pool_0] + worker_pools

ray_spec = RaySpec(resource_pool_images=resource_pool_images)
disabled = not enable_metrics_collection
ray_metric_spec = RayMetricSpec(disabled=disabled)
ray_spec = RaySpec(
resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec
)
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
persistent_resource = PersistentResource(
resource_pools=resource_pools,
Expand Down
93 changes: 21 additions & 72 deletions tests/unit/vertex_ray/test_cluster_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@


# -*- coding: utf-8 -*-
# TODO(b/328684671)
_EXPECTED_MASK = field_mask_pb2.FieldMask(paths=["resource_pools.replica_count"])

# for manual scaling
Expand Down Expand Up @@ -241,6 +240,22 @@ def update_persistent_resource_2_pools_mock():
yield update_persistent_resource_2_pools_mock


def cluster_eq(returned_cluster, expected_cluster):
assert vars(returned_cluster.head_node_type) == vars(
expected_cluster.head_node_type
)
assert vars(returned_cluster.worker_node_types[0]) == vars(
expected_cluster.worker_node_types[0]
)
assert (
returned_cluster.cluster_resource_name == expected_cluster.cluster_resource_name
)
assert returned_cluster.python_version == expected_cluster.python_version
assert returned_cluster.ray_version == expected_cluster.ray_version
assert returned_cluster.network == expected_cluster.network
assert returned_cluster.state == expected_cluster.state


@pytest.mark.usefixtures("google_auth_mock", "get_project_number_mock")
class TestClusterManagement:
def setup_method(self):
Expand Down Expand Up @@ -315,6 +330,7 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success(
network=tc.ProjectConstants.TEST_VPC_NETWORK,
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
labels=tc.ClusterConstants.TEST_LABELS,
enable_metrics_collection=False,
)

assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
Expand Down Expand Up @@ -465,21 +481,7 @@ def test_get_ray_cluster_success(self, get_persistent_resource_1_pool_mock):
)

get_persistent_resource_1_pool_mock.assert_called_once()

assert vars(cluster.head_node_type) == vars(
tc.ClusterConstants.TEST_CLUSTER.head_node_type
)
assert vars(cluster.worker_node_types[0]) == vars(
tc.ClusterConstants.TEST_CLUSTER.worker_node_types[0]
)
assert (
cluster.cluster_resource_name
== tc.ClusterConstants.TEST_CLUSTER.cluster_resource_name
)
assert cluster.python_version == tc.ClusterConstants.TEST_CLUSTER.python_version
assert cluster.ray_version == tc.ClusterConstants.TEST_CLUSTER.ray_version
assert cluster.network == tc.ClusterConstants.TEST_CLUSTER.network
assert cluster.state == tc.ClusterConstants.TEST_CLUSTER.state
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER)

def test_get_ray_cluster_with_custom_image_success(
self, get_persistent_resource_2_pools_custom_image_mock
Expand All @@ -489,27 +491,7 @@ def test_get_ray_cluster_with_custom_image_success(
)

get_persistent_resource_2_pools_custom_image_mock.assert_called_once()

assert vars(cluster.head_node_type) == vars(
tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.head_node_type
)
assert vars(cluster.worker_node_types[0]) == vars(
tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.worker_node_types[0]
)
assert (
cluster.cluster_resource_name
== tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.cluster_resource_name
)
assert (
cluster.python_version
== tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.python_version
)
assert (
cluster.ray_version
== tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.ray_version
)
assert cluster.network == tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.network
assert cluster.state == tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.state
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE)

@pytest.mark.usefixtures("get_persistent_resource_exception_mock")
def test_get_ray_cluster_error(self):
Expand All @@ -526,42 +508,9 @@ def test_list_ray_clusters_success(self, list_persistent_resources_mock):
list_persistent_resources_mock.assert_called_once()

# first ray cluster
assert vars(clusters[0].head_node_type) == vars(
tc.ClusterConstants.TEST_CLUSTER.head_node_type
)
assert vars(clusters[0].worker_node_types[0]) == vars(
tc.ClusterConstants.TEST_CLUSTER.worker_node_types[0]
)
assert (
clusters[0].cluster_resource_name
== tc.ClusterConstants.TEST_CLUSTER.cluster_resource_name
)
assert (
clusters[0].python_version
== tc.ClusterConstants.TEST_CLUSTER.python_version
)
assert clusters[0].ray_version == tc.ClusterConstants.TEST_CLUSTER.ray_version
assert clusters[0].network == tc.ClusterConstants.TEST_CLUSTER.network
assert clusters[0].state == tc.ClusterConstants.TEST_CLUSTER.state

cluster_eq(clusters[0], tc.ClusterConstants.TEST_CLUSTER)
# second ray cluster
assert vars(clusters[1].head_node_type) == vars(
tc.ClusterConstants.TEST_CLUSTER_2.head_node_type
)
assert vars(clusters[1].worker_node_types[0]) == vars(
tc.ClusterConstants.TEST_CLUSTER_2.worker_node_types[0]
)
assert (
clusters[1].cluster_resource_name
== tc.ClusterConstants.TEST_CLUSTER_2.cluster_resource_name
)
assert (
clusters[1].python_version
== tc.ClusterConstants.TEST_CLUSTER_2.python_version
)
assert clusters[1].ray_version == tc.ClusterConstants.TEST_CLUSTER_2.ray_version
assert clusters[1].network == tc.ClusterConstants.TEST_CLUSTER_2.network
assert clusters[1].state == tc.ClusterConstants.TEST_CLUSTER_2.state
cluster_eq(clusters[1], tc.ClusterConstants.TEST_CLUSTER_2)

def test_list_ray_clusters_initialized_success(
self, get_project_number_mock, list_persistent_resources_mock
Expand Down
40 changes: 31 additions & 9 deletions tests/unit/vertex_ray/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
PersistentResource,
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
RayMetricSpec,
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
ResourcePool,
)
Expand Down Expand Up @@ -116,22 +119,31 @@ class ClusterConstants:
TEST_REQUEST_RUNNING_1_POOL = PersistentResource(
resource_pools=[TEST_RESOURCE_POOL_0],
resource_runtime_spec=ResourceRuntimeSpec(
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_GPU_IMAGE}),
ray_spec=RaySpec(
resource_pool_images={"head-node": TEST_GPU_IMAGE},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
)
TEST_REQUEST_RUNNING_1_POOL_WITH_LABELS = PersistentResource(
resource_pools=[TEST_RESOURCE_POOL_0],
resource_runtime_spec=ResourceRuntimeSpec(
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_GPU_IMAGE}),
ray_spec=RaySpec(
resource_pool_images={"head-node": TEST_GPU_IMAGE},
ray_metric_spec=RayMetricSpec(disabled=True),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
labels=TEST_LABELS,
)
TEST_REQUEST_RUNNING_1_POOL_CUSTOM_IMAGES = PersistentResource(
resource_pools=[TEST_RESOURCE_POOL_0],
resource_runtime_spec=ResourceRuntimeSpec(
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_CUSTOM_IMAGE}),
ray_spec=RaySpec(
resource_pool_images={"head-node": TEST_CUSTOM_IMAGE},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
)
Expand All @@ -140,7 +152,10 @@ class ClusterConstants:
name=TEST_VERTEX_RAY_PR_ADDRESS,
resource_pools=[TEST_RESOURCE_POOL_0],
resource_runtime_spec=ResourceRuntimeSpec(
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_GPU_IMAGE}),
ray_spec=RaySpec(
resource_pool_images={"head-node": TEST_GPU_IMAGE},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
resource_runtime=ResourceRuntime(
Expand All @@ -156,7 +171,10 @@ class ClusterConstants:
name=TEST_VERTEX_RAY_PR_ADDRESS,
resource_pools=[TEST_RESOURCE_POOL_0],
resource_runtime_spec=ResourceRuntimeSpec(
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_CUSTOM_IMAGE}),
ray_spec=RaySpec(
resource_pool_images={"head-node": TEST_CUSTOM_IMAGE},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
resource_runtime=ResourceRuntime(
Expand Down Expand Up @@ -218,7 +236,8 @@ class ClusterConstants:
resource_pool_images={
"head-node": TEST_CPU_IMAGE,
"worker-pool1": TEST_GPU_IMAGE,
}
},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
Expand All @@ -230,7 +249,8 @@ class ClusterConstants:
resource_pool_images={
"head-node": TEST_CUSTOM_IMAGE,
"worker-pool1": TEST_CUSTOM_IMAGE,
}
},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
Expand All @@ -243,7 +263,8 @@ class ClusterConstants:
resource_pool_images={
"head-node": TEST_CPU_IMAGE,
"worker-pool1": TEST_GPU_IMAGE,
}
},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
Expand All @@ -263,7 +284,8 @@ class ClusterConstants:
resource_pool_images={
"head-node": TEST_CUSTOM_IMAGE,
"worker-pool1": TEST_CUSTOM_IMAGE,
}
},
ray_metric_spec=RayMetricSpec(disabled=False),
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
Expand Down

0 comments on commit e33d11f

Please sign in to comment.