Skip to content

Commit

Permalink
feat: Enable Ray Job submission without VPC peering
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 641037130
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Jun 6, 2024
1 parent 6592042 commit 37875b5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
23 changes: 20 additions & 3 deletions google/cloud/aiplatform/vertex_ray/dashboard_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,27 @@ def get_job_submission_client_cluster_info(
"RAY_HEAD_NODE_INTERNAL_IP", None
)
if head_address is None:
raise RuntimeError(
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
# No peering. Try to get the dashboard address.
dashboard_address = response.resource_runtime.access_uris.get(
"RAY_DASHBOARD_URI", None
)

if dashboard_address is None:
raise RuntimeError(
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
)
if _validation_utils.valid_dashboard_address(dashboard_address):
bearer_token = _validation_utils.get_bearer_token()
if kwargs.get("headers", None) is None:
kwargs["headers"] = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(bearer_token),
}
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
address=dashboard_address,
_use_tls=True,
*args,
**kwargs,
)
# Assume that the head node internal IP is in the form xxx.xxx.xxx.xxx:10001.
# A Ray-on-Vertex cluster serves the Dashboard at port 8888 instead of
# Ray's default dashboard port, 8265.
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/vertex_ray/test_dashboard_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def get_persistent_resource_status_running_mock():
yield get_persistent_resource


@pytest.fixture
def get_persistent_resource_status_running_byosa_public_mock():
    """Yield a patched ``get_persistent_resource`` for a BYOSA cluster without peering.

    While the fixture is active,
    ``vertex_ray.util._gapic_utils.get_persistent_resource`` is replaced by a
    mock whose return value is
    ``tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA``. The mock
    itself is yielded so tests can inspect how it was called.
    """
    patcher = mock.patch.object(
        vertex_ray.util._gapic_utils, "get_persistent_resource"
    )
    with patcher as patched_lookup:
        patched_lookup.return_value = (
            tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA
        )
        yield patched_lookup


@pytest.fixture
def get_bearer_token_mock():
with mock.patch.object(
Expand Down Expand Up @@ -112,3 +124,27 @@ def test_job_submission_client_cluster_info_with_dashboard_address(
_use_tls=True,
headers=tc.ClusterConstants.TEST_HEADERS,
)

@pytest.mark.usefixtures(
    "get_persistent_resource_status_running_byosa_public_mock", "google_auth_mock"
)
def test_job_submission_client_cluster_info_with_cluster_name_byosa_public(
    self,
    ray_get_job_submission_client_cluster_info_mock,
    get_bearer_token_mock,
    get_project_number_mock,
):
    """Check the cluster-name lookup path for a BYOSA cluster without peering.

    Calls ``vertex_ray.get_job_submission_client_cluster_info`` with a
    cluster ID and verifies that the project number is looked up once, a
    bearer token is requested once, and the underlying Ray call receives the
    dashboard address, ``_use_tls=True`` and the test headers.
    """
    project_id = tc.ProjectConstants.TEST_GCP_PROJECT_ID
    aiplatform.init(project=project_id)

    vertex_ray.get_job_submission_client_cluster_info(
        tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID
    )

    # The project number must be resolved exactly once for the project.
    expected_project_name = "projects/{}".format(project_id)
    get_project_number_mock.assert_called_once_with(name=expected_project_name)

    # The bearer token is fetched with no arguments.
    get_bearer_token_mock.assert_called_once_with()

    # The Ray SDK call must receive the dashboard address over TLS.
    expected_kwargs = {
        "address": tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
        "_use_tls": True,
        "headers": tc.ClusterConstants.TEST_HEADERS,
    }
    ray_get_job_submission_client_cluster_info_mock.assert_called_once_with(
        **expected_kwargs
    )

0 comments on commit 37875b5

Please sign in to comment.