Skip to content
This repository has been archived by the owner on Sep 5, 2023. It is now read-only.

feat: add context manager support in client #122

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,12 @@ async def test_iam_permissions(
# Done; return the response.
return response

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
18 changes: 14 additions & 4 deletions google/cloud/billing_v1/services/cloud_billing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def get_billing_account(
Expand Down Expand Up @@ -1261,6 +1258,19 @@ def test_iam_permissions(
# Done; return the response.
return response

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
"""Releases underlying transport's resources.

.. warning::
ONLY use as a context manager if the transport is NOT shared
with other clients! Exiting the with block will CLOSE the transport
and may cause errors in other clients!
"""
self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def get_billing_account(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,5 +578,8 @@ def test_iam_permissions(
)
return self._stubs["test_iam_permissions"]

def close(self):
self.grpc_channel.close()


__all__ = ("CloudBillingGrpcTransport",)
Original file line number Diff line number Diff line change
Expand Up @@ -585,5 +585,8 @@ def test_iam_permissions(
)
return self._stubs["test_iam_permissions"]

def close(self):
return self.grpc_channel.close()


__all__ = ("CloudBillingGrpcAsyncIOTransport",)
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ async def list_skus(
# Done; return the response.
return response

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
18 changes: 14 additions & 4 deletions google/cloud/billing_v1/services/cloud_catalog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,7 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
always_use_jwt_access=True,
)

def list_services(
Expand Down Expand Up @@ -487,6 +484,19 @@ def list_skus(
# Done; return the response.
return response

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
"""Releases underlying transport's resources.

.. warning::
ONLY use as a context manager if the transport is NOT shared
with other clients! Exiting the with block will CLOSE the transport
and may cause errors in other clients!
"""
self.transport.close()


try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def list_services(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,5 +282,8 @@ def list_skus(
)
return self._stubs["list_skus"]

def close(self):
self.grpc_channel.close()


__all__ = ("CloudCatalogGrpcTransport",)
Original file line number Diff line number Diff line change
Expand Up @@ -288,5 +288,8 @@ def list_skus(
)
return self._stubs["list_skus"]

def close(self):
return self.grpc_channel.close()


__all__ = ("CloudCatalogGrpcAsyncIOTransport",)
9 changes: 9 additions & 0 deletions google/cloud/billing_v1/types/cloud_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class ProjectBillingInfo(proto.Message):

class GetBillingAccountRequest(proto.Message):
r"""Request message for ``GetBillingAccount``.

Attributes:
name (str):
Required. The resource name of the billing account to
Expand All @@ -126,6 +127,7 @@ class GetBillingAccountRequest(proto.Message):

class ListBillingAccountsRequest(proto.Message):
r"""Request message for ``ListBillingAccounts``.

Attributes:
page_size (int):
Requested page size. The maximum page size is
Expand All @@ -152,6 +154,7 @@ class ListBillingAccountsRequest(proto.Message):

class ListBillingAccountsResponse(proto.Message):
r"""Response message for ``ListBillingAccounts``.

Attributes:
billing_accounts (Sequence[google.cloud.billing_v1.types.BillingAccount]):
A list of billing accounts.
Expand All @@ -174,6 +177,7 @@ def raw_page(self):

class CreateBillingAccountRequest(proto.Message):
r"""Request message for ``CreateBillingAccount``.

Attributes:
billing_account (google.cloud.billing_v1.types.BillingAccount):
Required. The billing account resource to
Expand All @@ -188,6 +192,7 @@ class CreateBillingAccountRequest(proto.Message):

class UpdateBillingAccountRequest(proto.Message):
r"""Request message for ``UpdateBillingAccount``.

Attributes:
name (str):
Required. The name of the billing account
Expand All @@ -209,6 +214,7 @@ class UpdateBillingAccountRequest(proto.Message):

class ListProjectBillingInfoRequest(proto.Message):
r"""Request message for ``ListProjectBillingInfo``.

Attributes:
name (str):
Required. The resource name of the billing account
Expand All @@ -231,6 +237,7 @@ class ListProjectBillingInfoRequest(proto.Message):

class ListProjectBillingInfoResponse(proto.Message):
r"""Request message for ``ListProjectBillingInfoResponse``.

Attributes:
project_billing_info (Sequence[google.cloud.billing_v1.types.ProjectBillingInfo]):
A list of ``ProjectBillingInfo`` resources representing the
Expand All @@ -254,6 +261,7 @@ def raw_page(self):

class GetProjectBillingInfoRequest(proto.Message):
r"""Request message for ``GetProjectBillingInfo``.

Attributes:
name (str):
Required. The resource name of the project for which billing
Expand All @@ -266,6 +274,7 @@ class GetProjectBillingInfoRequest(proto.Message):

class UpdateProjectBillingInfoRequest(proto.Message):
r"""Request message for ``UpdateProjectBillingInfo``.

Attributes:
name (str):
Required. The resource name of the project associated with
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/billing_v1/types/cloud_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

class Service(proto.Message):
r"""Encapsulates a single service in Google Cloud Platform.

Attributes:
name (str):
The resource name for the service.
Expand All @@ -62,6 +63,7 @@ class Service(proto.Message):

class Sku(proto.Message):
r"""Encapsulates a single SKU in Google Cloud Platform

Attributes:
name (str):
The resource name for the SKU.
Expand Down Expand Up @@ -101,6 +103,7 @@ class Sku(proto.Message):

class Category(proto.Message):
r"""Represents the category hierarchy of a SKU.

Attributes:
service_display_name (str):
The display name of the service this SKU
Expand Down Expand Up @@ -284,6 +287,7 @@ class AggregationInterval(proto.Enum):

class ListServicesRequest(proto.Message):
r"""Request message for ``ListServices``.

Attributes:
page_size (int):
Requested page size. Defaults to 5000.
Expand All @@ -300,6 +304,7 @@ class ListServicesRequest(proto.Message):

class ListServicesResponse(proto.Message):
r"""Response message for ``ListServices``.

Attributes:
services (Sequence[google.cloud.billing_v1.types.Service]):
A list of services.
Expand All @@ -320,6 +325,7 @@ def raw_page(self):

class ListSkusRequest(proto.Message):
r"""Request message for ``ListSkus``.

Attributes:
parent (str):
Required. The name of the service.
Expand Down Expand Up @@ -361,6 +367,7 @@ class ListSkusRequest(proto.Message):

class ListSkusResponse(proto.Message):
r"""Response message for ``ListSkus``.

Attributes:
skus (Sequence[google.cloud.billing_v1.types.Sku]):
The list of public SKUs of the given service.
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/gapic/billing_v1/test_cloud_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from google.api_core import gapic_v1
from google.api_core import grpc_helpers
from google.api_core import grpc_helpers_async
from google.api_core import path_template
from google.auth import credentials as ga_credentials
from google.auth.exceptions import MutualTLSChannelError
from google.cloud.billing_v1.services.cloud_billing import CloudBillingAsyncClient
Expand Down Expand Up @@ -3017,6 +3018,9 @@ def test_cloud_billing_base_transport():
with pytest.raises(NotImplementedError):
getattr(transport, method)(request=object())

with pytest.raises(NotImplementedError):
transport.close()


@requires_google_auth_gte_1_25_0
def test_cloud_billing_base_transport_with_credentials_file():
Expand Down Expand Up @@ -3464,3 +3468,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
)
prep.assert_called_once_with(client_info)


@pytest.mark.asyncio
async def test_transport_close_async():
client = CloudBillingAsyncClient(
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
)
with mock.patch.object(
type(getattr(client.transport, "grpc_channel")), "close"
) as close:
async with client:
close.assert_not_called()
close.assert_called_once()


def test_transport_close():
transports = {
"grpc": "_grpc_channel",
}

for transport, close_name in transports.items():
client = CloudBillingClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
with mock.patch.object(
type(getattr(client.transport, close_name)), "close"
) as close:
with client:
close.assert_not_called()
close.assert_called_once()


def test_client_ctx():
transports = [
"grpc",
]
for transport in transports:
client = CloudBillingClient(
credentials=ga_credentials.AnonymousCredentials(), transport=transport
)
# Test client calls underlying transport.
with mock.patch.object(type(client.transport), "close") as close:
close.assert_not_called()
with client:
pass
close.assert_called()
Loading