Skip to content

Commit

Permalink
feat: add context manager support in client (#637)
Browse files Browse the repository at this point in the history
- [ ] Regenerate this pull request now.

PiperOrigin-RevId: 408420890

Source-Link: googleapis/googleapis@2921f9f

Source-Link: googleapis/googleapis-gen@6598ca8
Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiNjU5OGNhOGNiYmY1MjI2NzMzYTA5OWM0NTA2NTE4YTVhZjZmZjc0YyJ9

docs: list oneofs in docstring
fix(deps): require google-api-core >= 1.28.0
fix(deps): drop packaging dependency
feat: add context manager support in client
chore: fix docstring for first attribute of protos
fix: improper types in pagers generation
chore: use gapic-generator-python 0.56.2
  • Loading branch information
gcf-owl-bot[bot] authored Nov 10, 2021
1 parent 5e0c364 commit 5ae4be8
Show file tree
Hide file tree
Showing 35 changed files with 1,195 additions and 913 deletions.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
#
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Iterable,
Sequence,
Tuple,
Optional,
Iterator,
)

from google.cloud.spanner_admin_database_v1.types import backup
Expand Down Expand Up @@ -76,14 +76,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[spanner_database_admin.ListDatabasesResponse]:
def pages(self) -> Iterator[spanner_database_admin.ListDatabasesResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[spanner_database_admin.Database]:
def __iter__(self) -> Iterator[spanner_database_admin.Database]:
for page in self.pages:
yield from page.databases

Expand Down Expand Up @@ -140,14 +140,14 @@ def __getattr__(self, name: str) -> Any:
@property
async def pages(
self,
) -> AsyncIterable[spanner_database_admin.ListDatabasesResponse]:
) -> AsyncIterator[spanner_database_admin.ListDatabasesResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[spanner_database_admin.Database]:
def __aiter__(self) -> AsyncIterator[spanner_database_admin.Database]:
async def async_generator():
async for page in self.pages:
for response in page.databases:
Expand Down Expand Up @@ -206,14 +206,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[backup.ListBackupsResponse]:
def pages(self) -> Iterator[backup.ListBackupsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[backup.Backup]:
def __iter__(self) -> Iterator[backup.Backup]:
for page in self.pages:
yield from page.backups

Expand Down Expand Up @@ -268,14 +268,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
async def pages(self) -> AsyncIterable[backup.ListBackupsResponse]:
async def pages(self) -> AsyncIterator[backup.ListBackupsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[backup.Backup]:
def __aiter__(self) -> AsyncIterator[backup.Backup]:
async def async_generator():
async for page in self.pages:
for response in page.backups:
Expand Down Expand Up @@ -334,14 +334,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[spanner_database_admin.ListDatabaseOperationsResponse]:
def pages(self) -> Iterator[spanner_database_admin.ListDatabaseOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[operations_pb2.Operation]:
def __iter__(self) -> Iterator[operations_pb2.Operation]:
for page in self.pages:
yield from page.operations

Expand Down Expand Up @@ -400,14 +400,14 @@ def __getattr__(self, name: str) -> Any:
@property
async def pages(
self,
) -> AsyncIterable[spanner_database_admin.ListDatabaseOperationsResponse]:
) -> AsyncIterator[spanner_database_admin.ListDatabaseOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[operations_pb2.Operation]:
def __aiter__(self) -> AsyncIterator[operations_pb2.Operation]:
async def async_generator():
async for page in self.pages:
for response in page.operations:
Expand Down Expand Up @@ -466,14 +466,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[backup.ListBackupOperationsResponse]:
def pages(self) -> Iterator[backup.ListBackupOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[operations_pb2.Operation]:
def __iter__(self) -> Iterator[operations_pb2.Operation]:
for page in self.pages:
yield from page.operations

Expand Down Expand Up @@ -528,14 +528,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
async def pages(self) -> AsyncIterable[backup.ListBackupOperationsResponse]:
async def pages(self) -> AsyncIterator[backup.ListBackupOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[operations_pb2.Operation]:
def __aiter__(self) -> AsyncIterator[operations_pb2.Operation]:
async def async_generator():
async for page in self.pages:
for response in page.operations:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
#
import abc
from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
import packaging.version
import pkg_resources

import google.auth # type: ignore
import google.api_core # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import operations_v1 # type: ignore
import google.api_core
from google.api_core import exceptions as core_exceptions
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

Expand All @@ -44,15 +43,6 @@
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()

try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = google.auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None


class DatabaseAdminTransport(abc.ABC):
"""Abstract transport class for DatabaseAdmin."""
Expand Down Expand Up @@ -105,7 +95,7 @@ def __init__(
host += ":443"
self._host = host

scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}

# Save the scopes.
self._scopes = scopes
Expand All @@ -127,7 +117,7 @@ def __init__(
**scopes_kwargs, quota_project_id=quota_project_id
)

# If the credentials is service account credentials, then always try to use self signed JWT.
# If the credentials are service account credentials, then always try to use self signed JWT.
if (
always_use_jwt_access
and isinstance(credentials, service_account.Credentials)
Expand All @@ -138,29 +128,6 @@ def __init__(
# Save the credentials.
self._credentials = credentials

# TODO(busunkim): This method is in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-auth is increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
def _get_scopes_kwargs(
cls, host: str, scopes: Optional[Sequence[str]]
) -> Dict[str, Optional[Sequence[str]]]:
"""Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""

scopes_kwargs = {}

if _GOOGLE_AUTH_VERSION and (
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}

return scopes_kwargs

def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down Expand Up @@ -363,8 +330,17 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.
.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def operations_client(self) -> operations_v1.OperationsClient:
def operations_client(self):
"""Return the client designed to process long-running operations."""
raise NotImplementedError()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import grpc_helpers # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import grpc_helpers
from google.api_core import operations_v1
from google.api_core import gapic_v1
import google.auth # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
Expand Down Expand Up @@ -92,16 +92,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
Expand All @@ -122,7 +122,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
self._operations_client = None
self._operations_client: Optional[operations_v1.OperationsClient] = None

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
Expand Down Expand Up @@ -815,5 +815,8 @@ def list_backup_operations(
)
return self._stubs["list_backup_operations"]

def close(self):
self.grpc_channel.close()


__all__ = ("DatabaseAdminGrpcTransport",)
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import gapic_v1 # type: ignore
from google.api_core import grpc_helpers_async # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.api_core import gapic_v1
from google.api_core import grpc_helpers_async
from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version

import grpc # type: ignore
from grpc.experimental import aio # type: ignore
Expand Down Expand Up @@ -139,16 +138,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
Expand All @@ -169,7 +168,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
self._operations_client = None
self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
Expand Down Expand Up @@ -833,5 +832,8 @@ def list_backup_operations(
)
return self._stubs["list_backup_operations"]

def close(self):
return self.grpc_channel.close()


__all__ = ("DatabaseAdminGrpcAsyncIOTransport",)
3 changes: 3 additions & 0 deletions google/cloud/spanner_admin_database_v1/types/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

class Backup(proto.Message):
r"""A backup of a Cloud Spanner database.
Attributes:
database (str):
Required for the
Expand Down Expand Up @@ -461,6 +462,7 @@ def raw_page(self):

class BackupInfo(proto.Message):
r"""Information about a backup.
Attributes:
backup (str):
Name of the backup.
Expand Down Expand Up @@ -491,6 +493,7 @@ class BackupInfo(proto.Message):

class CreateBackupEncryptionConfig(proto.Message):
r"""Encryption configuration for the backup to create.
Attributes:
encryption_type (google.cloud.spanner_admin_database_v1.types.CreateBackupEncryptionConfig.EncryptionType):
Required. The encryption type of the backup.
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_admin_database_v1/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class OperationProgress(proto.Message):

class EncryptionConfig(proto.Message):
r"""Encryption configuration for a Cloud Spanner database.
Attributes:
kms_key_name (str):
The Cloud KMS key to be used for encrypting and decrypting
Expand Down
Loading

0 comments on commit 5ae4be8

Please sign in to comment.