Skip to content

Commit

Permalink
fix: support VPC and BYOSA case in Ray on Vertex JobSubmissionClient …
Browse files Browse the repository at this point in the history
…using cluster resource name

PiperOrigin-RevId: 642446002
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Jun 12, 2024
1 parent 17c59c4 commit 662d039
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 45 deletions.
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/vertex_ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, address: Optional[str]) -> None:
address,
" failed to start Head node properly because custom service"
" account isn't supported in peered VPC network. Use public"
" endpoint instead (createa a cluster withought specifying"
" endpoint instead (create a cluster without specifying"
" VPC network).",
)
else:
Expand Down
60 changes: 19 additions & 41 deletions google/cloud/aiplatform/vertex_ray/dashboard_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,55 +46,33 @@ def get_job_submission_client_cluster_info(
Raises:
RuntimeError if head_address is None.
"""
# If passing the dashboard uri, programmatically get headers
if _validation_utils.valid_dashboard_address(address):
bearer_token = _validation_utils.get_bearer_token()
if kwargs.get("headers", None) is None:
kwargs["headers"] = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(bearer_token),
}
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
address=address,
_use_tls=True,
*args,
**kwargs,
)
address = _validation_utils.maybe_reconstruct_resource_name(address)
_validation_utils.valid_resource_name(address)
dashboard_address = address
else:
address = _validation_utils.maybe_reconstruct_resource_name(address)
_validation_utils.valid_resource_name(address)

resource_name = address
response = _gapic_utils.get_persistent_resource(resource_name)

resource_name = address
response = _gapic_utils.get_persistent_resource(resource_name)
head_address = response.resource_runtime.access_uris.get(
"RAY_HEAD_NODE_INTERNAL_IP", None
)
if head_address is None:
# No peering. Try to get the dashboard address.
dashboard_address = response.resource_runtime.access_uris.get(
"RAY_DASHBOARD_URI", None
)

if dashboard_address is None:
raise RuntimeError(
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
)
if _validation_utils.valid_dashboard_address(dashboard_address):
bearer_token = _validation_utils.get_bearer_token()
if kwargs.get("headers", None) is None:
kwargs["headers"] = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(bearer_token),
}
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
address=dashboard_address,
_use_tls=True,
*args,
**kwargs,
)
# Assume that head node internal IP in a form of xxx.xxx.xxx.xxx:10001.
# Ray-on-Vertex cluster serves the Dashboard at port 8888 instead of
# the default 8251.
head_address = ":".join([head_address.split(":")[0], "8888"])

# If passing the dashboard uri, programmatically get headers
bearer_token = _validation_utils.get_bearer_token()
if kwargs.get("headers", None) is None:
kwargs["headers"] = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(bearer_token),
}
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
address=head_address, *args, **kwargs
address=dashboard_address,
_use_tls=True,
*args,
**kwargs,
)
16 changes: 13 additions & 3 deletions tests/unit/vertex_ray/test_dashboard_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,22 @@ def setup_method(self):
def teardown_method(self):
aiplatform.initializer.global_pool.shutdown(wait=True)

@pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
@pytest.mark.usefixtures(
"get_persistent_resource_status_running_mock", "google_auth_mock"
)
def test_job_submission_client_cluster_info_with_full_resource_name(
self,
ray_get_job_submission_client_cluster_info_mock,
get_bearer_token_mock,
):
vertex_ray.get_job_submission_client_cluster_info(
tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS
)
get_bearer_token_mock.assert_called_once_with()
ray_get_job_submission_client_cluster_info_mock.assert_called_once_with(
address=tc.ClusterConstants.TEST_VERTEX_RAY_JOB_CLIENT_IP
address=tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
_use_tls=True,
headers=tc.ClusterConstants.TEST_HEADERS,
)

@pytest.mark.usefixtures(
Expand All @@ -92,6 +98,7 @@ def test_job_submission_client_cluster_info_with_cluster_name(
self,
ray_get_job_submission_client_cluster_info_mock,
get_project_number_mock,
get_bearer_token_mock,
):
aiplatform.init(project=tc.ProjectConstants.TEST_GCP_PROJECT_ID)

Expand All @@ -101,8 +108,11 @@ def test_job_submission_client_cluster_info_with_cluster_name(
get_project_number_mock.assert_called_once_with(
name="projects/{}".format(tc.ProjectConstants.TEST_GCP_PROJECT_ID)
)
get_bearer_token_mock.assert_called_once_with()
ray_get_job_submission_client_cluster_info_mock.assert_called_once_with(
address=tc.ClusterConstants.TEST_VERTEX_RAY_JOB_CLIENT_IP
address=tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
_use_tls=True,
headers=tc.ClusterConstants.TEST_HEADERS,
)

@pytest.mark.usefixtures(
Expand Down

0 comments on commit 662d039

Please sign in to comment.