Skip to content

Commit

Permalink
feat: Add transport override to enable the use of REST instead of GRPC
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611159115
  • Loading branch information
matthew29tang authored and Copybara-Service committed Feb 28, 2024
1 parent 02829f1 commit 6ab4084
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 55 deletions.
24 changes: 24 additions & 0 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self):
self._network = None
self._service_account = None
self._api_endpoint = None
self._api_transport = None

def init(
self,
Expand All @@ -121,6 +122,7 @@ def init(
network: Optional[str] = None,
service_account: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_transport: Optional[str] = None,
):
"""Updates common initialization parameters with provided options.
Expand Down Expand Up @@ -179,6 +181,8 @@ def init(
api_endpoint (str):
Optional. The desired API endpoint,
e.g., us-central1-aiplatform.googleapis.com
api_transport (str):
Optional. The transport method which is either 'grpc' or 'rest'
Raises:
ValueError:
If experiment_description is provided but experiment is not.
Expand Down Expand Up @@ -231,6 +235,15 @@ def init(
backing_tensorboard=experiment_tensorboard,
)

if api_transport:
VALID_TRANSPORT_TYPES = ["grpc", "rest"]
if api_transport not in VALID_TRANSPORT_TYPES:
raise ValueError(
f"{api_transport} is not a valid transport type. "
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
)
self._api_transport = api_transport

def get_encryption_spec(
self,
encryption_spec_key_name: Optional[str],
Expand Down Expand Up @@ -481,6 +494,17 @@ def create_client(
"client_info": client_info,
}

# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
if self._api_transport == "rest":
if "Async" in client_class.__name__:
# Warn user that "rest" is not supported and use grpc instead
logging.warning(
"REST is not supported for async clients, "
+ "falling back to grpc."
)
else:
kwargs["transport"] = self._api_transport

return client_class(**kwargs)


Expand Down
31 changes: 25 additions & 6 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def __init__(
client_options: client_options.ClientOptions,
client_info: gapic_v1.client_info.ClientInfo,
credentials: Optional[auth_credentials.Credentials] = None,
transport: Optional[str] = None,
):
"""Stores parameters needed to instantiate client.
Expand All @@ -400,20 +401,30 @@ def __init__(
Required. Client info to pass to client.
credentials (auth_credentials.credentials):
Optional. Client credentials to pass to client.
transport (str):
Optional. Transport type to pass to client.
"""

self._client_class = client_class
self._credentials = credentials
self._client_options = client_options
self._client_info = client_info
self._api_transport = transport

def __getattr__(self, name: str) -> Any:
"""Instantiates client and returns attribute of the client."""
temporary_client = self._client_class(

kwargs = dict(
credentials=self._credentials,
client_options=self._client_options,
client_info=self._client_info,
)

if self._api_transport is not None:
kwargs["transport"] = self._api_transport

temporary_client = self._client_class(**kwargs)

return getattr(temporary_client, name)

@property
Expand Down Expand Up @@ -448,6 +459,7 @@ def __init__(
client_options: client_options.ClientOptions,
client_info: gapic_v1.client_info.ClientInfo,
credentials: Optional[auth_credentials.Credentials] = None,
transport: Optional[str] = None,
):
"""Stores parameters needed to instantiate client.
Expand All @@ -458,21 +470,28 @@ def __init__(
Required. Client info to pass to client.
credentials (auth_credentials.credentials):
Optional. Client credentials to pass to client.
transport (str):
Optional. Transport type to pass to client.
"""
kwargs = dict(
credentials=credentials,
client_options=client_options,
client_info=client_info,
)

if transport is not None:
kwargs["transport"] = transport

self._clients = {
version: self.WrappedClient(
client_class=client_class,
client_options=client_options,
client_info=client_info,
credentials=credentials,
transport=transport,
)
if self._is_temporary
else client_class(
client_options=client_options,
client_info=client_info,
credentials=credentials,
)
else client_class(**kwargs)
for version, client_class in self._version_map
}

Expand Down
Loading

0 comments on commit 6ab4084

Please sign in to comment.