diff --git a/google/auth/transport/mtls.py b/google/auth/transport/mtls.py index 063b26504..5b742306b 100644 --- a/google/auth/transport/mtls.py +++ b/google/auth/transport/mtls.py @@ -58,3 +58,45 @@ def callback(): return cert_bytes, key_bytes return callback + + +def default_client_encrypted_cert_source(cert_path, key_path): + """Get a callback which returns the default encrpyted client SSL credentials. + + Args: + cert_path (str): The cert file path. The default client certificate will + be written to this file when the returned callback is called. + key_path (str): The key file path. The default encrypted client key will + be written to this file when the returned callback is called. + + Returns: + Callable[[], [str, str, bytes]]: A callback which generates the default + client certificate, encrpyted private key and passphrase. It writes + the certificate and private key into the cert_path and key_path, and + returns the cert_path, key_path and passphrase bytes. + + Raises: + google.auth.exceptions.DefaultClientCertSourceError: If any problem + occurs when loading or saving the client certificate and key. + """ + if not has_default_client_cert_source(): + raise exceptions.MutualTLSChannelError( + "Default client encrypted cert source doesn't exist" + ) + + def callback(): + try: + _, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials( + generate_encrypted_key=True + ) + with open(cert_path, "wb") as cert_file: + cert_file.write(cert_bytes) + with open(key_path, "wb") as key_file: + key_file.write(key_bytes) + except (exceptions.ClientCertError, OSError) as caught_exc: + new_exc = exceptions.MutualTLSChannelError(caught_exc) + six.raise_from(new_exc, caught_exc) + + return cert_path, key_path, passphrase_bytes + + return callback diff --git a/tests/transport/test_mtls.py b/tests/transport/test_mtls.py index d3bc3915a..ff70bb3c2 100644 --- a/tests/transport/test_mtls.py +++ b/tests/transport/test_mtls.py @@ -53,3 +53,31 @@ def test_default_client_cert_source( callback = mtls.default_client_cert_source() with pytest.raises(exceptions.MutualTLSChannelError): callback() + + +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +def test_default_client_encrypted_cert_source( + has_default_client_cert_source, get_client_ssl_credentials +): + # Test default client cert source doesn't exist. + has_default_client_cert_source.return_value = False + with pytest.raises(exceptions.MutualTLSChannelError): + mtls.default_client_encrypted_cert_source("cert_path", "key_path") + + # The following tests will assume default client cert source exists. + has_default_client_cert_source.return_value = True + + # Test good callback. + get_client_ssl_credentials.return_value = (True, b"cert", b"key", b"passphrase") + callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path") + with mock.patch("{}.open".format(__name__), return_value=mock.MagicMock()): + assert callback() == ("cert_path", "key_path", b"passphrase") + + # Test bad callback which throws exception. + get_client_ssl_credentials.side_effect = exceptions.ClientCertError() + callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path") + with pytest.raises(exceptions.MutualTLSChannelError): + callback()