Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add helper func to for default encrypted cert #514

Merged
merged 4 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions google/auth/transport/mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,45 @@ def callback():
return cert_bytes, key_bytes

return callback


def default_client_encrypted_cert_source(cert_path, key_path):
"""Get a callback which returns the default encrpyted client SSL credentials.

Args:
cert_path (str): The cert file path. The default client certificate will
be written to this file when the returned callback is called.
key_path (str): The key file path. The default encrypted client key will
be written to this file when the returned callback is called.

Returns:
Callable[[], [str, str, bytes]]: A callback which generates the default
client certificate, encrpyted private key and passphrase. It writes
the certificate and private key into the cert_path and key_path, and
returns the cert_path, key_path and passphrase bytes.

Raises:
google.auth.exceptions.DefaultClientCertSourceError: If any problem
occurs when loading or saving the client certificate and key.
"""
if not has_default_client_cert_source():
raise exceptions.MutualTLSChannelError(
"Default client encrypted cert source doesn't exist"
)

def callback():
try:
_, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials(
generate_encrypted_key=True
)
with open(cert_path, "wb") as cert_file:
cert_file.write(cert_bytes)
with open(key_path, "wb") as key_file:
key_file.write(key_bytes)
except (exceptions.ClientCertError, OSError) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
six.raise_from(new_exc, caught_exc)

return cert_path, key_path, passphrase_bytes

return callback
28 changes: 28 additions & 0 deletions tests/transport/test_mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,31 @@ def test_default_client_cert_source(
callback = mtls.default_client_cert_source()
with pytest.raises(exceptions.MutualTLSChannelError):
callback()


@mock.patch(
"google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
)
@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True)
def test_default_client_encrypted_cert_source(
has_default_client_cert_source, get_client_ssl_credentials
):
# Test default client cert source doesn't exist.
has_default_client_cert_source.return_value = False
with pytest.raises(exceptions.MutualTLSChannelError):
mtls.default_client_encrypted_cert_source("cert_path", "key_path")

# The following tests will assume default client cert source exists.
has_default_client_cert_source.return_value = True

# Test good callback.
get_client_ssl_credentials.return_value = (True, b"cert", b"key", b"passphrase")
callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
with mock.patch("{}.open".format(__name__), return_value=mock.MagicMock()):
assert callback() == ("cert_path", "key_path", b"passphrase")

# Test bad callback which throws exception.
get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
with pytest.raises(exceptions.MutualTLSChannelError):
callback()