diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 index 12e92de7d1..af3a3b76d5 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 @@ -2,6 +2,7 @@ {% block content %} from collections import OrderedDict +import os import re from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union import pkg_resources @@ -11,6 +12,8 @@ from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore {% filter sort_lines -%} @@ -144,21 +147,47 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): transport (Union[str, ~.{{ service.name }}Transport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (ClientOptions): Custom options for the client. It + won't take effect unless ``transport`` is None. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. - (2) If ``transport`` argument is None, ``client_options`` can be - used to create a mutual TLS transport. If ``client_cert_source`` - is provided, mutual TLS transport will be created with the given - ``api_endpoint`` or the default mTLS endpoint, and the client - SSL credentials obtained from ``client_cert_source``. + default endpoint provided by the client. GOOGLE_API_USE_MTLS + environment variable can also be used to override the endpoint: + "Always" (always use the default mTLS endpoint), "Never" (always + use the default regular endpoint, this is the default value for + the environment variable) and "Auto" (auto switch to the default + mTLS endpoint if client SSL credentials is present). However, + the ``api_endpoint`` property takes precedence if provided. + (2) The ``client_cert_source`` property is used to provide client + SSL credentials for mutual TLS transport. If not provided, the + default SSL credentials will be used if present. Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): client_options = ClientOptions.from_dict(client_options) + if client_options is None: + client_options = ClientOptions.ClientOptions() + + if transport is None and client_options.api_endpoint is None: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "Never") + if use_mtls_env == "Never": + client_options.api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "Always": + client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "Auto": + has_client_cert_source = ( + client_options.client_cert_source is not None + or mtls.has_default_client_cert_source() + ) + client_options.api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if has_client_cert_source else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: Never, Auto, Always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -169,38 +198,16 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): raise ValueError('When providing a transport instance, ' 'provide its credentials directly.') self._transport = transport - elif client_options is None or ( - client_options.api_endpoint is None - and client_options.client_cert_source is None - ): - # Don't trigger mTLS if we get an empty ClientOptions. + elif isinstance(transport, str): Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, host=self.DEFAULT_ENDPOINT ) else: - # We have a non-empty ClientOptions. If client_cert_source is - # provided, trigger mTLS with user provided endpoint or the default - # mTLS endpoint. - if client_options.client_cert_source: - api_mtls_endpoint = ( - client_options.api_endpoint - if client_options.api_endpoint - else self.DEFAULT_MTLS_ENDPOINT - ) - else: - api_mtls_endpoint = None - - api_endpoint = ( - client_options.api_endpoint - if client_options.api_endpoint - else self.DEFAULT_ENDPOINT - ) - self._transport = {{ service.name }}GrpcTransport( credentials=credentials, - host=api_endpoint, - api_mtls_endpoint=api_mtls_endpoint, + host=client_options.api_endpoint, + api_mtls_endpoint=client_options.api_endpoint, client_cert_source=client_options.client_cert_source, ) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 index c42770f9a0..1632b77621 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 @@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore {%- if service.has_lro %} from google.api_core import operations_v1 # type: ignore {%- endif %} +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -63,7 +64,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): is None. Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if channel: @@ -76,6 +77,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): elif api_mtls_endpoint: host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -96,7 +100,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): # Run the base constructor. super().__init__(host=host, credentials=credentials) - self._stubs = {} # type: Dict[str, Callable] + self._stubs = {} # type: Dict[str, Callable] @classmethod diff --git a/gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index 5e2e5e2870..1e2043b621 100644 --- a/gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -1,6 +1,7 @@ {% extends "_base.py.j2" %} {% block content %} +import os from unittest import mock import grpc @@ -11,6 +12,7 @@ import pytest {% filter sort_lines -%} from google import auth from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError from google.oauth2 import service_account from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports @@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(): {% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %} +def test_{{ service.client_name|snake_case }}_get_transport_class(): + transport = {{ service.client_name }}.get_transport_class() + assert transport == transports.{{ service.name }}GrpcTransport + + transport = {{ service.client_name }}.get_transport_class("grpc") + assert transport == transports.{{ service.name }}GrpcTransport + + def test_{{ service.client_name|snake_case }}_client_options(): # Check that if channel is provided we won't create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: @@ -72,58 +82,99 @@ def test_{{ service.client_name|snake_case }}_client_options(): client = {{ service.client_name }}(transport=transport) gtc.assert_not_called() - # Check mTLS is not triggered with empty client options. - options = client_options.ClientOptions() + # Check that if channel is provided via str we will create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - transport = gtc.return_value = mock.MagicMock() - client = {{ service.client_name }}(client_options=options) - transport.assert_called_once_with( - credentials=None, - host=client.DEFAULT_ENDPOINT, - ) + client = {{ service.client_name }}(transport="grpc") + gtc.assert_called() - # Check mTLS is not triggered if api_endpoint is provided but - # client_cert_source is None. + # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint=None, + api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, host="squid.clam.whelk", ) - # Check mTLS is triggered if client_cert_source is provided. - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # "Never". + os.environ["GOOGLE_API_USE_MTLS"] = "Never" with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = {{ service.client_name }}(client_options=options) + client = {{ service.client_name }}() grpc_transport.assert_called_once_with( - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, + api_mtls_endpoint=client.DEFAULT_ENDPOINT, + client_cert_source=None, credentials=None, host=client.DEFAULT_ENDPOINT, ) - # Check mTLS is triggered if api_endpoint and client_cert_source are provided. - options = client_options.ClientOptions( - api_endpoint="squid.clam.whelk", - client_cert_source=client_cert_source_callback - ) + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # "Always". + os.environ["GOOGLE_API_USE_MTLS"] = "Always" + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = {{ service.client_name }}() + grpc_transport.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "Auto", and client_cert_source is provided. + os.environ["GOOGLE_API_USE_MTLS"] = "Auto" + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint="squid.clam.whelk", + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, client_cert_source=client_cert_source_callback, credentials=None, - host="squid.clam.whelk", + host=client.DEFAULT_MTLS_ENDPOINT, ) + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "Auto", and default_client_cert_source is provided. + os.environ["GOOGLE_API_USE_MTLS"] = "Auto" + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + grpc_transport.return_value = None + client = {{ service.client_name }}() + grpc_transport.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "Auto", but client_cert_source and default_client_cert_source are None. + os.environ["GOOGLE_API_USE_MTLS"] = "Auto" + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + grpc_transport.return_value = None + client = {{ service.client_name }}() + grpc_transport.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_ENDPOINT, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has + # unsupported value. + os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported" + with pytest.raises(MutualTLSChannelError): + client = {{ service.client_name }}() + + del os.environ["GOOGLE_API_USE_MTLS"] + def test_{{ service.client_name|snake_case }}_client_options_from_dict(): with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: @@ -132,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict(): client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( - api_mtls_endpoint=None, + api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, host="squid.clam.whelk", @@ -490,12 +541,24 @@ def test_{{ service.name|snake_case }}_auth_adc(): )) +def test_{{ service.name|snake_case }}_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk") + adc.assert_called_once_with(scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + )) + + def test_{{ service.name|snake_case }}_host_no_port(): {% with host = (service.host|default('localhost', true)).split(':')[0] -%} client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), client_options=client_options.ClientOptions(api_endpoint='{{ host }}'), - transport='grpc', ) assert client._transport._host == '{{ host }}:443' {% endwith %} @@ -506,7 +569,6 @@ def test_{{ service.name|snake_case }}_host_with_port(): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'), - transport='grpc', ) assert client._transport._host == '{{ host }}:8000' {% endwith %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index a3c9d5b9b3..1915722a0c 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -2,6 +2,7 @@ {% block content %} from collections import OrderedDict +import os import re from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union import pkg_resources @@ -11,6 +12,8 @@ from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore {% filter sort_lines -%} @@ -144,21 +147,47 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): transport (Union[str, ~.{{ service.name }}Transport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (ClientOptions): Custom options for the client. It + won't take effect unless ``transport`` is None. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. - (2) If ``transport`` argument is None, ``client_options`` can be - used to create a mutual TLS transport. If ``client_cert_source`` - is provided, mutual TLS transport will be created with the given - ``api_endpoint`` or the default mTLS endpoint, and the client - SSL credentials obtained from ``client_cert_source``. + default endpoint provided by the client. GOOGLE_API_USE_MTLS + environment variable can also be used to override the endpoint: + "Always" (always use the default mTLS endpoint), "Never" (always + use the default regular endpoint, this is the default value for + the environment variable) and "Auto" (auto switch to the default + mTLS endpoint if client SSL credentials is present). However, + the ``api_endpoint`` property takes precedence if provided. + (2) The ``client_cert_source`` property is used to provide client + SSL credentials for mutual TLS transport. If not provided, the + default SSL credentials will be used if present. Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): client_options = ClientOptions.from_dict(client_options) + if client_options is None: + client_options = ClientOptions.ClientOptions() + + if transport is None and client_options.api_endpoint is None: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "Never") + if use_mtls_env == "Never": + client_options.api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "Always": + client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "Auto": + has_client_cert_source = ( + client_options.client_cert_source is not None + or mtls.has_default_client_cert_source() + ) + client_options.api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if has_client_cert_source else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: Never, Auto, Always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -169,38 +198,16 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): raise ValueError('When providing a transport instance, ' 'provide its credentials directly.') self._transport = transport - elif client_options is None or ( - client_options.api_endpoint is None - and client_options.client_cert_source is None - ): - # Don't trigger mTLS if we get an empty ClientOptions. + elif isinstance(transport, str): Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, host=self.DEFAULT_ENDPOINT ) else: - # We have a non-empty ClientOptions. If client_cert_source is - # provided, trigger mTLS with user provided endpoint or the default - # mTLS endpoint. - if client_options.client_cert_source: - api_mtls_endpoint = ( - client_options.api_endpoint - if client_options.api_endpoint - else self.DEFAULT_MTLS_ENDPOINT - ) - else: - api_mtls_endpoint = None - - api_endpoint = ( - client_options.api_endpoint - if client_options.api_endpoint - else self.DEFAULT_ENDPOINT - ) - self._transport = {{ service.name }}GrpcTransport( credentials=credentials, - host=api_endpoint, - api_mtls_endpoint=api_mtls_endpoint, + host=client_options.api_endpoint, + api_mtls_endpoint=client_options.api_endpoint, client_cert_source=client_options.client_cert_source, ) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index eb47dbdc52..1632b77621 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore {%- if service.has_lro %} from google.api_core import operations_v1 # type: ignore {%- endif %} +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -63,7 +64,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): is None. Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if channel: @@ -76,6 +77,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): elif api_mtls_endpoint: host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: diff --git a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index 3a1aa0e350..a613c6753f 100644 --- a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -1,6 +1,7 @@ {% extends "_base.py.j2" %} {% block content %} +import os from unittest import mock import grpc @@ -11,6 +12,7 @@ import pytest {% filter sort_lines -%} from google import auth from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError from google.oauth2 import service_account from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports @@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(): {% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %} +def test_{{ service.client_name|snake_case }}_get_transport_class(): + transport = {{ service.client_name }}.get_transport_class() + assert transport == transports.{{ service.name }}GrpcTransport + + transport = {{ service.client_name }}.get_transport_class("grpc") + assert transport == transports.{{ service.name }}GrpcTransport + + def test_{{ service.client_name|snake_case }}_client_options(): # Check that if channel is provided we won't create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: @@ -72,58 +82,100 @@ def test_{{ service.client_name|snake_case }}_client_options(): client = {{ service.client_name }}(transport=transport) gtc.assert_not_called() - # Check mTLS is not triggered with empty client options. - options = client_options.ClientOptions() + # Check that if channel is provided via str we will create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - transport = gtc.return_value = mock.MagicMock() - client = {{ service.client_name }}(client_options=options) - transport.assert_called_once_with( - credentials=None, - host=client.DEFAULT_ENDPOINT, - ) + client = {{ service.client_name }}(transport="grpc") + gtc.assert_called() - # Check mTLS is not triggered if api_endpoint is provided but - # client_cert_source is None. + # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint=None, + api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, host="squid.clam.whelk", ) - # Check mTLS is triggered if client_cert_source is provided. - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # "Never". + os.environ["GOOGLE_API_USE_MTLS"] = "Never" with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = {{ service.client_name }}(client_options=options) + client = {{ service.client_name }}() grpc_transport.assert_called_once_with( - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, + api_mtls_endpoint=client.DEFAULT_ENDPOINT, + client_cert_source=None, credentials=None, host=client.DEFAULT_ENDPOINT, ) - # Check mTLS is triggered if api_endpoint and client_cert_source are provided. - options = client_options.ClientOptions( - api_endpoint="squid.clam.whelk", - client_cert_source=client_cert_source_callback - ) + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # "Always". + os.environ["GOOGLE_API_USE_MTLS"] = "Always" + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = {{ service.client_name }}() + grpc_transport.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "Auto", and client_cert_source is provided. + os.environ["GOOGLE_API_USE_MTLS"] = "Auto" + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint="squid.clam.whelk", + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, client_cert_source=client_cert_source_callback, credentials=None, - host="squid.clam.whelk", + host=client.DEFAULT_MTLS_ENDPOINT, ) + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "Auto", and default_client_cert_source is provided. + os.environ["GOOGLE_API_USE_MTLS"] = "Auto" + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + grpc_transport.return_value = None + client = {{ service.client_name }}() + grpc_transport.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "Auto", but client_cert_source and default_client_cert_source are None. + os.environ["GOOGLE_API_USE_MTLS"] = "Auto" + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + grpc_transport.return_value = None + client = {{ service.client_name }}() + grpc_transport.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_ENDPOINT, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has + # unsupported value. + os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported" + with pytest.raises(MutualTLSChannelError): + client = {{ service.client_name }}() + + del os.environ["GOOGLE_API_USE_MTLS"] + + def test_{{ service.client_name|snake_case }}_client_options_from_dict(): with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None @@ -131,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict(): client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( - api_mtls_endpoint=None, + api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, host="squid.clam.whelk", @@ -557,12 +609,24 @@ def test_{{ service.name|snake_case }}_auth_adc(): )) +def test_{{ service.name|snake_case }}_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk") + adc.assert_called_once_with(scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + )) + + def test_{{ service.name|snake_case }}_host_no_port(): {% with host = (service.host|default('localhost', true)).split(':')[0] -%} client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), client_options=client_options.ClientOptions(api_endpoint='{{ host }}'), - transport='grpc', ) assert client._transport._host == '{{ host }}:443' {% endwith %} @@ -573,7 +637,6 @@ def test_{{ service.name|snake_case }}_host_with_port(): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'), - transport='grpc', ) assert client._transport._host == '{{ host }}:8000' {% endwith %}