Skip to content

Commit

Permalink
feat: add experimental GDCH support (#1022)
Browse files Browse the repository at this point in the history
* feat: add experimental GDCH support

* address comments

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* remove quota project id

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
arithmetic1728 and gcf-owl-bot[bot] committed May 10, 2022
1 parent 62daa73 commit 5367aac
Show file tree
Hide file tree
Showing 7 changed files with 578 additions and 20 deletions.
46 changes: 45 additions & 1 deletion google/auth/_default.py
Expand Up @@ -36,11 +36,13 @@
_SERVICE_ACCOUNT_TYPE = "service_account"
_EXTERNAL_ACCOUNT_TYPE = "external_account"
_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account"
_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account"
_VALID_TYPES = (
_AUTHORIZED_USER_TYPE,
_SERVICE_ACCOUNT_TYPE,
_EXTERNAL_ACCOUNT_TYPE,
_IMPERSONATED_SERVICE_ACCOUNT_TYPE,
_GDCH_SERVICE_ACCOUNT_TYPE,
)

# Help message when no credentials can be found.
Expand Down Expand Up @@ -134,6 +136,8 @@ def load_credentials_from_file(
def _load_credentials_from_info(
filename, info, scopes, default_scopes, quota_project_id, request
):
from google.auth.credentials import CredentialsWithQuotaProject

credential_type = info.get("type")

if credential_type == _AUTHORIZED_USER_TYPE:
Expand All @@ -158,14 +162,17 @@ def _load_credentials_from_info(
credentials, project_id = _get_impersonated_service_account_credentials(
filename, info, scopes
)
elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE:
credentials, project_id = _get_gdch_service_account_credentials(info)
else:
raise exceptions.DefaultCredentialsError(
"The file {file} does not have a valid type. "
"Type is {type}, expected one of {valid_types}.".format(
file=filename, type=credential_type, valid_types=_VALID_TYPES
)
)
credentials = _apply_quota_project_id(credentials, quota_project_id)
if isinstance(credentials, CredentialsWithQuotaProject):
credentials = _apply_quota_project_id(credentials, quota_project_id)
return credentials, project_id


Expand Down Expand Up @@ -430,6 +437,36 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
return credentials, None


def _get_gdch_service_account_credentials(info):
from google.oauth2 import gdch_credentials

k8s_ca_cert_path = info.get("k8s_ca_cert_path")
k8s_cert_path = info.get("k8s_cert_path")
k8s_key_path = info.get("k8s_key_path")
k8s_token_endpoint = info.get("k8s_token_endpoint")
ais_ca_cert_path = info.get("ais_ca_cert_path")
ais_token_endpoint = info.get("ais_token_endpoint")

format_version = info.get("format_version")
if format_version != "v1":
raise exceptions.DefaultCredentialsError(
"format_version is not provided or unsupported. Supported version is: v1"
)

return (
gdch_credentials.ServiceAccountCredentials(
k8s_ca_cert_path,
k8s_cert_path,
k8s_key_path,
k8s_token_endpoint,
ais_ca_cert_path,
ais_token_endpoint,
None,
),
None,
)


def _apply_quota_project_id(credentials, quota_project_id):
if quota_project_id:
credentials = credentials.with_quota_project(quota_project_id)
Expand Down Expand Up @@ -465,6 +502,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
endpoint.
The project ID returned in this case is the one corresponding to the
underlying workload identity pool resource if determinable.
If the environment variable is set to the path of a valid GDCH service
account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH
credential will be returned. The project ID returned is None unless it
is set via `GOOGLE_CLOUD_PROJECT` environment variable.
2. If the `Google Cloud SDK`_ is installed and has application default
credentials set they are loaded and returned.
Expand Down Expand Up @@ -499,6 +541,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
.. _Metadata Service: https://cloud.google.com/compute/docs\
/storing-retrieving-metadata
.. _Cloud Run: https://cloud.google.com/run
.. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\
/hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted
Example::
Expand Down
84 changes: 66 additions & 18 deletions google/oauth2/_client.py
Expand Up @@ -44,11 +44,13 @@ def _handle_error_response(response_data):
"""Translates an error response into an exception.
Args:
response_data (Mapping): The decoded response data.
response_data (Mapping | str): The decoded response data.
Raises:
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
if isinstance(response_data, six.string_types):
raise exceptions.RefreshError(response_data)
try:
error_details = "{}: {}".format(
response_data["error"], response_data.get("error_description")
Expand Down Expand Up @@ -79,7 +81,13 @@ def _parse_expiry(response_data):


def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
request,
token_uri,
body,
access_token=None,
use_json=False,
expected_status_code=http_client.OK,
**kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Expand All @@ -93,6 +101,16 @@ def _token_endpoint_request_no_throw(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
expected_status_code (Optional(int)): The expected the status code of
the token response. The default value is 200. We may expect other
status code like 201 for GDCH credentials.
kwargs: Additional arguments passed on to the request method. The
kwargs will be passed to `requests.request` method, see:
https://docs.python-requests.org/en/latest/api/#requests.request.
For example, you can use `cert=("cert_pem_path", "key_pem_path")`
to set up client side SSL certificate, and use
`verify="ca_bundle_path"` to set up the CA certificates for sever
side SSL certificate verification.
Returns:
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
Expand All @@ -112,32 +130,46 @@ def _token_endpoint_request_no_throw(
# retry to fetch token for maximum of two times if any internal failure
# occurs.
while True:
response = request(method="POST", url=token_uri, headers=headers, body=body)
response = request(
method="POST", url=token_uri, headers=headers, body=body, **kwargs
)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)
response_data = json.loads(response_body)

if response.status == http_client.OK:
if response.status == expected_status_code:
# response_body should be a JSON
response_data = json.loads(response_body)
break
else:
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data
# For a failed response, response_body could be a string
try:
response_data = json.loads(response_body)
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
except ValueError:
response_data = response_body
return False, response_data

return response.status == expected_status_code, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
request,
token_uri,
body,
access_token=None,
use_json=False,
expected_status_code=http_client.OK,
**kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Expand All @@ -150,6 +182,16 @@ def _token_endpoint_request(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
expected_status_code (Optional(int)): The expected the status code of
the token response. The default value is 200. We may expect other
status code like 201 for GDCH credentials.
kwargs: Additional arguments passed on to the request method. The
kwargs will be passed to `requests.request` method, see:
https://docs.python-requests.org/en/latest/api/#requests.request.
For example, you can use `cert=("cert_pem_path", "key_pem_path")`
to set up client side SSL certificate, and use
`verify="ca_bundle_path"` to set up the CA certificates for sever
side SSL certificate verification.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Expand All @@ -159,7 +201,13 @@ def _token_endpoint_request(
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
request,
token_uri,
body,
access_token=access_token,
use_json=use_json,
expected_status_code=expected_status_code,
**kwargs
)
if not response_status_ok:
_handle_error_response(response_data)
Expand Down

0 comments on commit 5367aac

Please sign in to comment.