Navigation Menu

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add context manager support in client #637

Merged
merged 2 commits into from Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Expand Up @@ -15,13 +15,13 @@
#
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Iterable,
Sequence,
Tuple,
Optional,
Iterator,
)

from google.cloud.spanner_admin_database_v1.types import backup
Expand Down Expand Up @@ -76,14 +76,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[spanner_database_admin.ListDatabasesResponse]:
def pages(self) -> Iterator[spanner_database_admin.ListDatabasesResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[spanner_database_admin.Database]:
def __iter__(self) -> Iterator[spanner_database_admin.Database]:
for page in self.pages:
yield from page.databases

Expand Down Expand Up @@ -140,14 +140,14 @@ def __getattr__(self, name: str) -> Any:
@property
async def pages(
self,
) -> AsyncIterable[spanner_database_admin.ListDatabasesResponse]:
) -> AsyncIterator[spanner_database_admin.ListDatabasesResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[spanner_database_admin.Database]:
def __aiter__(self) -> AsyncIterator[spanner_database_admin.Database]:
async def async_generator():
async for page in self.pages:
for response in page.databases:
Expand Down Expand Up @@ -206,14 +206,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[backup.ListBackupsResponse]:
def pages(self) -> Iterator[backup.ListBackupsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[backup.Backup]:
def __iter__(self) -> Iterator[backup.Backup]:
for page in self.pages:
yield from page.backups

Expand Down Expand Up @@ -268,14 +268,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
async def pages(self) -> AsyncIterable[backup.ListBackupsResponse]:
async def pages(self) -> AsyncIterator[backup.ListBackupsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[backup.Backup]:
def __aiter__(self) -> AsyncIterator[backup.Backup]:
async def async_generator():
async for page in self.pages:
for response in page.backups:
Expand Down Expand Up @@ -334,14 +334,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[spanner_database_admin.ListDatabaseOperationsResponse]:
def pages(self) -> Iterator[spanner_database_admin.ListDatabaseOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[operations_pb2.Operation]:
def __iter__(self) -> Iterator[operations_pb2.Operation]:
for page in self.pages:
yield from page.operations

Expand Down Expand Up @@ -400,14 +400,14 @@ def __getattr__(self, name: str) -> Any:
@property
async def pages(
self,
) -> AsyncIterable[spanner_database_admin.ListDatabaseOperationsResponse]:
) -> AsyncIterator[spanner_database_admin.ListDatabaseOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[operations_pb2.Operation]:
def __aiter__(self) -> AsyncIterator[operations_pb2.Operation]:
async def async_generator():
async for page in self.pages:
for response in page.operations:
Expand Down Expand Up @@ -466,14 +466,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
def pages(self) -> Iterable[backup.ListBackupOperationsResponse]:
def pages(self) -> Iterator[backup.ListBackupOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> Iterable[operations_pb2.Operation]:
def __iter__(self) -> Iterator[operations_pb2.Operation]:
for page in self.pages:
yield from page.operations

Expand Down Expand Up @@ -528,14 +528,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)

@property
async def pages(self) -> AsyncIterable[backup.ListBackupOperationsResponse]:
async def pages(self) -> AsyncIterator[backup.ListBackupOperationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> AsyncIterable[operations_pb2.Operation]:
def __aiter__(self) -> AsyncIterator[operations_pb2.Operation]:
async def async_generator():
async for page in self.pages:
for response in page.operations:
Expand Down
Expand Up @@ -15,15 +15,14 @@
#
import abc
from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
import packaging.version
import pkg_resources

import google.auth # type: ignore
import google.api_core # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import operations_v1 # type: ignore
import google.api_core
from google.api_core import exceptions as core_exceptions
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

Expand All @@ -44,15 +43,6 @@
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()

try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = google.auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None


class DatabaseAdminTransport(abc.ABC):
"""Abstract transport class for DatabaseAdmin."""
Expand Down Expand Up @@ -105,7 +95,7 @@ def __init__(
host += ":443"
self._host = host

scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}

# Save the scopes.
self._scopes = scopes
Expand All @@ -127,7 +117,7 @@ def __init__(
**scopes_kwargs, quota_project_id=quota_project_id
)

# If the credentials is service account credentials, then always try to use self signed JWT.
# If the credentials are service account credentials, then always try to use self signed JWT.
if (
always_use_jwt_access
and isinstance(credentials, service_account.Credentials)
Expand All @@ -138,29 +128,6 @@ def __init__(
# Save the credentials.
self._credentials = credentials

# TODO(busunkim): This method is in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-auth is increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
def _get_scopes_kwargs(
cls, host: str, scopes: Optional[Sequence[str]]
) -> Dict[str, Optional[Sequence[str]]]:
"""Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""

scopes_kwargs = {}

if _GOOGLE_AUTH_VERSION and (
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}

return scopes_kwargs

def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down Expand Up @@ -363,8 +330,17 @@ def _prep_wrapped_messages(self, client_info):
),
}

def close(self):
"""Closes resources associated with the transport.

.. warning::
Only call this method if the transport is NOT shared
with other clients - this may cause errors in other clients!
"""
raise NotImplementedError()

@property
def operations_client(self) -> operations_v1.OperationsClient:
def operations_client(self):
"""Return the client designed to process long-running operations."""
raise NotImplementedError()

Expand Down
Expand Up @@ -16,9 +16,9 @@
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import grpc_helpers # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import grpc_helpers
from google.api_core import operations_v1
from google.api_core import gapic_v1
import google.auth # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
Expand Down Expand Up @@ -92,16 +92,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
Expand All @@ -122,7 +122,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
self._operations_client = None
self._operations_client: Optional[operations_v1.OperationsClient] = None

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
Expand Down Expand Up @@ -815,5 +815,8 @@ def list_backup_operations(
)
return self._stubs["list_backup_operations"]

def close(self):
self.grpc_channel.close()


__all__ = ("DatabaseAdminGrpcTransport",)
Expand Up @@ -16,12 +16,11 @@
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import gapic_v1 # type: ignore
from google.api_core import grpc_helpers_async # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.api_core import gapic_v1
from google.api_core import grpc_helpers_async
from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version

import grpc # type: ignore
from grpc.experimental import aio # type: ignore
Expand Down Expand Up @@ -139,16 +138,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
Expand All @@ -169,7 +168,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
self._operations_client = None
self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
Expand Down Expand Up @@ -833,5 +832,8 @@ def list_backup_operations(
)
return self._stubs["list_backup_operations"]

def close(self):
return self.grpc_channel.close()


__all__ = ("DatabaseAdminGrpcAsyncIOTransport",)
3 changes: 3 additions & 0 deletions google/cloud/spanner_admin_database_v1/types/backup.py
Expand Up @@ -42,6 +42,7 @@

class Backup(proto.Message):
r"""A backup of a Cloud Spanner database.

Attributes:
database (str):
Required for the
Expand Down Expand Up @@ -461,6 +462,7 @@ def raw_page(self):

class BackupInfo(proto.Message):
r"""Information about a backup.

Attributes:
backup (str):
Name of the backup.
Expand Down Expand Up @@ -491,6 +493,7 @@ class BackupInfo(proto.Message):

class CreateBackupEncryptionConfig(proto.Message):
r"""Encryption configuration for the backup to create.

Attributes:
encryption_type (google.cloud.spanner_admin_database_v1.types.CreateBackupEncryptionConfig.EncryptionType):
Required. The encryption type of the backup.
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_admin_database_v1/types/common.py
Expand Up @@ -47,6 +47,7 @@ class OperationProgress(proto.Message):

class EncryptionConfig(proto.Message):
r"""Encryption configuration for a Cloud Spanner database.

Attributes:
kms_key_name (str):
The Cloud KMS key to be used for encrypting and decrypting
Expand Down