Skip to content

Commit

Permalink
feat: allow user-provided client info (#573)
Browse files Browse the repository at this point in the history
Fix for googleapis/python-kms#37, #566, and similar.
  • Loading branch information
software-dov committed Aug 17, 2020
1 parent 7c2bab7 commit b2e5274
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 36 deletions.
Expand Up @@ -23,7 +23,7 @@ from google.oauth2 import service_account # type: ignore
{% endfor -%}
{% endfor -%}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .transports.grpc import {{ service.name }}GrpcTransport


Expand Down Expand Up @@ -135,6 +135,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -160,6 +161,11 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
(2) The ``client_cert_source`` property is used to provide client
SSL credentials for mutual TLS transport. If not provided, the
default SSL credentials will be used if present.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand Down Expand Up @@ -209,6 +215,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
host=client_options.api_endpoint,
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
client_info=client_info,
)


Expand Down
Expand Up @@ -21,13 +21,13 @@ from google.auth import credentials # type: ignore
{% endfilter %}

try:
_client_info = gapic_v1.client_info.ClientInfo(
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


class {{ service.name }}Transport(metaclass=abc.ABCMeta):
Expand All @@ -43,6 +43,7 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: credentials.Credentials = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the transport.

Expand All @@ -54,6 +55,11 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -69,9 +75,9 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
self._credentials = credentials

# Lifted into its own function so it can be stubbed out during tests.
self._prep_wrapped_messages()
self._prep_wrapped_messages(client_info)

def _prep_wrapped_messages(self):
def _prep_wrapped_messages(self, client_info):
# Precomputed wrapped methods
self._wrapped_methods = {
{% for method in service.methods.values() -%}
Expand All @@ -92,7 +98,7 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
client_info=client_info,
),
{% endfor %} {# precomputed wrappers loop #}
}
Expand Down
Expand Up @@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
{%- endif %}
from google.api_core import gapic_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
Expand All @@ -20,7 +21,7 @@ import grpc # type: ignore
{{ method.output.ident.python_import }}
{% endfor -%}
{% endfilter %}
from .base import {{ service.name }}Transport
from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO


class {{ service.name }}GrpcTransport({{ service.name }}Transport):
Expand All @@ -40,7 +41,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials: credentials.Credentials = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the transport.

Args:
Expand All @@ -62,6 +65,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand Down Expand Up @@ -101,7 +109,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
self._stubs = {} # type: Dict[str, Callable]

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
super().__init__(
host=host,
credentials=credentials,
client_info=client_info,
)


@classmethod
Expand Down
Expand Up @@ -24,9 +24,7 @@ from google.api_core import future
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% endif -%}
{% if service.has_pagers -%}
from google.api_core import gapic_v1
{% endif -%}
{% for method in service.methods.values() -%}
{% for ref_type in method.ref_types
if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation')
Expand Down Expand Up @@ -109,6 +107,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -122,6 +121,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -135,6 +135,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -149,6 +150,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=client_cert_source_callback,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -163,6 +165,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -177,6 +180,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
Expand All @@ -197,6 +201,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)


Expand Down Expand Up @@ -769,4 +774,23 @@ def test_parse_{{ message.resource_type|snake_case }}_path():
{% endwith -%}
{% endfor -%}

def test_client_withDEFAULT_CLIENT_INFO():
client_info = gapic_v1.client_info.ClientInfo()

with mock.patch.object(transports.{{ service.name }}Transport, '_prep_wrapped_messages') as prep:
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
client_info=client_info,
)
prep.assert_called_once_with(client_info)

with mock.patch.object(transports.{{ service.name }}Transport, '_prep_wrapped_messages') as prep:
transport_class = {{ service.client_name }}.get_transport_class()
transport = transport_class(
credentials=credentials.AnonymousCredentials(),
client_info=client_info,
)
prep.assert_called_once_with(client_info)


{% endblock %}
Expand Up @@ -25,7 +25,7 @@ from google.iam.v1 import iam_policy_pb2 as iam_policy # type: ignore
from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endif %}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
from .client import {{ service.client_name }}

Expand All @@ -52,6 +52,7 @@ class {{ service.async_client_name }}:
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = 'grpc_asyncio',
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand Down Expand Up @@ -87,6 +88,8 @@ class {{ service.async_client_name }}:
credentials=credentials,
transport=transport,
client_options=client_options,
client_info=client_info,

)

{% for method in service.methods.values() -%}
Expand Down Expand Up @@ -202,7 +205,7 @@ class {{ service.async_client_name }}:
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)
{%- if method.field_headers %}

Expand Down Expand Up @@ -352,7 +355,7 @@ class {{ service.async_client_name }}:
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.set_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -459,7 +462,7 @@ class {{ service.async_client_name }}:
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.get_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -510,7 +513,7 @@ class {{ service.async_client_name }}:
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.test_iam_permissions,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand All @@ -527,13 +530,13 @@ class {{ service.async_client_name }}:
{% endif %}

try:
_client_info = gapic_v1.client_info.ClientInfo(
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


__all__ = (
Expand Down
Expand Up @@ -27,7 +27,7 @@ from google.iam.v1 import iam_policy_pb2 as iam_policy # type: ignore
from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endif %}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .transports.grpc import {{ service.grpc_transport_name }}
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}

Expand Down Expand Up @@ -141,6 +141,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -166,7 +167,12 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
(2) The ``client_cert_source`` property is used to provide client
SSL credentials for mutual TLS transport. If not provided, the
default SSL credentials will be used if present.

client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
Expand Down Expand Up @@ -219,6 +225,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)


Expand Down Expand Up @@ -471,7 +478,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
rpc = gapic_v1.method.wrap_method(
self._transport.set_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -578,7 +585,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
rpc = gapic_v1.method.wrap_method(
self._transport.get_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -629,7 +636,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
rpc = gapic_v1.method.wrap_method(
self._transport.test_iam_permissions,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand All @@ -647,13 +654,13 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):


try:
_client_info = gapic_v1.client_info.ClientInfo(
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


__all__ = (
Expand Down

0 comments on commit b2e5274

Please sign in to comment.