Skip to content

Commit

Permalink
feat: Support public endpoint for Ray Client
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630181847
  • Loading branch information
yinghsienwu authored and Copybara-Service committed May 2, 2024
1 parent 4ce2f60 commit 57a5f78
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions google/cloud/aiplatform/preview/vertex_ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import grpc
import logging
from typing import Dict
from typing import Optional
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from ray import client_builder
from .render import VertexRayTemplate
from .util import _validation_utils
Expand Down Expand Up @@ -80,7 +82,8 @@ class VertexRayClientBuilder(client_builder.ClientBuilder):
def __init__(self, address: Optional[str]) -> None:
address = _validation_utils.maybe_reconstruct_resource_name(address)
_validation_utils.valid_resource_name(address)

self._credentials = None
self._metadata = None
self.vertex_address = address
logging.info(
"[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
Expand All @@ -89,9 +92,17 @@ def __init__(self, address: Optional[str]) -> None:
self.resource_name = address

self.response = _gapic_utils.get_persistent_resource(self.resource_name)
address = self.response.resource_runtime.access_uris.get(
private_address = self.response.resource_runtime.access_uris.get(
"RAY_HEAD_NODE_INTERNAL_IP"
)
public_address = self.response.resource_runtime.access_uris.get(
"RAY_CLIENT_ENDPOINT"
)
if public_address is None:
address = private_address
else:
address = public_address

if address is None:
persistent_resource_id = self.resource_name.split("/")[5]
raise ValueError(
Expand Down Expand Up @@ -143,6 +154,17 @@ def __init__(self, address: Optional[str]) -> None:
def connect(self) -> _VertexRayClientContext:
# Can send any other params to ray cluster here
logging.info("[Ray on Vertex AI]: Connecting...")

public_address = self.response.resource_runtime.access_uris.get(
"RAY_CLIENT_ENDPOINT"
)
if public_address:
self._credentials = grpc.ssl_channel_credentials()
bearer_token = _validation_utils.get_bearer_token()
self._metadata = [
("authorization", "Bearer {}".format(bearer_token)),
("x-goog-user-project", "{}".format(initializer.global_config.project)),
]
ray_client_context = super().connect()
ray_head_uris = self.response.resource_runtime.access_uris

Expand Down

0 comments on commit 57a5f78

Please sign in to comment.