diff --git a/docs/user-guide.rst b/docs/user-guide.rst index 0cb119127..05516393e 100644 --- a/docs/user-guide.rst +++ b/docs/user-guide.rst @@ -212,7 +212,7 @@ There is a separate library, `google-auth-oauthlib`_, that has some helpers for integrating with `requests-oauthlib`_ to provide support for obtaining user credentials. You can use :func:`google_auth_oauthlib.helpers.credentials_from_session` to obtain -:class:`google.oauth2.credentials.Credentials` from a +:class:`google.oauth2.credentials.Credentials` from a :class:`requests_oauthlib.OAuth2Session` as above:: from google_auth_oauthlib.helpers import credentials_from_session @@ -459,9 +459,9 @@ Error responses must include both the ``code`` and ``message`` fields. The library will populate the following environment variables when the executable is run: ``GOOGLE_EXTERNAL_ACCOUNT_AUDIENCE``: The audience -field from the credential configuration. Always present. +field from the credential configuration. Always present. ``GOOGLE_EXTERNAL_ACCOUNT_IMPERSONATED_EMAIL``: The service account -email. Only present when service account impersonation is used. +email. Only present when service account impersonation is used. ``GOOGLE_EXTERNAL_ACCOUNT_OUTPUT_FILE``: The output file location from the credential configuration. Only present when specified in the credential configuration. @@ -486,6 +486,117 @@ they do not meet your specific requirements. You can now `use the Auth library <#using-external-identities>`__ to call Google Cloud resources from an OIDC or SAML provider. + +Accessing resources using a custom supplier with OIDC or SAML +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This library also allows for a custom implementation of :class:`google.auth.identity_pool.SubjectTokenSupplier` +to be specificed when creating a :class:`google.auth.identity_pool.Credential`. The supplier must +return a valid OIDC or SAML2.0 subject token, which will then be exchanged for a +Google Cloud access token. If an error occurs during token retrieval, the supplier +should return a :class:`google.auth.exceptions.RefreshError` and indicate via the error +whether the subject token retrieval is retryable. +Any call to the supplier from the Identity Pool credential will send a :class:`google.auth.external_account.SupplierContext` +object, which contains the requested audience and subject type. Additionally, the credential will +send the :class:`google.auth.transport.requests.Request` passed in the credential refresh call which +can be used to make HTTP requests.:: + + from google.auth import exceptions + from google.auth import identity_pool + + class CustomSubjectTokenSupplier(identity_pool.SubjectTokenSupplier): + + def get_subject_token(self, context, request): + audience = context.audience + subject_token_type = context.subject_token_type + try: + # Attempt to return the valid subject token of the requested type for the requested audience. + except Exception as e: + # If token retrieval fails, raise a refresh error, setting retryable to true if the client should + # attempt to retrieve the subject token again. + raise exceptions.RefreshError(e, retryable=True) + + supplier = CustomSubjectTokenSupplier() + + credentials = identity_pool.Credentials( + AUDIENCE, # Set GCP Audience. + "urn:ietf:params:aws:token-type:jwt", # Set subject token type. + subject_token_supplier=supplier, # Set supplier. + scopes=SCOPES # Set desired scopes. + ) + +Where the `audience`_ is: ``///iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID`` +Where the following variables need to be substituted: + +* ``$PROJECT_NUMBER``: The project number. +* ``$POOL_ID``: The workload pool ID. +* ``$PROVIDER_ID``: The provider ID. + +The values for audience, service account impersonation URL, and any other builder field can also be found +by generating a `credential configuration file with the gcloud CLI`_. + +.. _audience: + https://cloud.google.com/iam/docs/best-practices-for-using-workload-identity-federation#provider-audience +.. _credential configuration file with the gcloud CLI: + https://cloud.google.com/sdk/gcloud/reference/iam/workload-identity-pools/create-cred-config + +Accessing resources using a custom supplier with AWS +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This library also allows for a custom implementation of :class:`google.auth.aws.AwsSecurityCredentialsSupplier` +to be specificed when creating a :class:`google.auth.aws.Credential`. The supplier must +return valid AWS security credentials, which will then be exchanged for a +Google Cloud access token. If an error occurs during credential retrieval, the supplier +should return a :class:`google.auth.exceptions.RefreshError` and indicate via the error +whether the credential retrieval is retryable. +Any call to the supplier from the Identity Pool credential will send a :class:`google.auth.external_account.SupplierContext` +object, which contains the requested audience and subject type. Additionally, the credential will +send the :class:`google.auth.transport.requests.Request` passed in the credential refresh call which +can be used to make HTTP requests.:: + + from google.auth import aws + from google.auth import exceptions + + class CustomAwsSecurityCredentialsSupplier(aws.AwsSecurityCredentialsSupplier): + + def get_aws_security_credentials(self, context, request): + audience = context.audience + try: + # Return valid AWS security credentials. These credentials are not cached by + # the google credential, so caching should be implemented in the supplier. + return aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, SESSION_TOKEN) + except Exception as e: + # If credentials retrieval fails, raise a refresh error, setting retryable to true if the client should + # attempt to retrieve the subject token again. + raise exceptions.RefreshError(e, retryable=True) + + def get_aws_region(self, context, request): + # Return active AWS region. + + supplier = CustomAwsSecurityCredentialsSupplier() + + credentials = aws.Credentials( + AUDIENCE, # Set GCP Audience. + "urn:ietf:params:aws:token-type:aws4_request", # Set AWS subject token type. + aws_security_token_supplier=supplier, # Set supplier. + scopes=SCOPES # Set desired scopes. + ) + +Where the `audience`_ is: ``///iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID`` +Where the following variables need to be substituted: + +* ``$PROJECT_NUMBER``: The project number. +* ``$POOL_ID``: The workload pool ID. +* ``$PROVIDER_ID``: The provider ID. + +The values for audience, service account impersonation URL, and any other builder field can also be found +by generating a `credential configuration file with the gcloud CLI`_. + +.. _audience: + https://cloud.google.com/iam/docs/best-practices-for-using-workload-identity-federation#provider-audience +.. _credential configuration file with the gcloud CLI: + https://cloud.google.com/sdk/gcloud/reference/iam/workload-identity-pools/create-cred-config + Using External Identities ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -774,6 +885,62 @@ Refer to the `using executable-sourced credentials with Workload Identity Federation `__ above for the executable response specification. +Accessing resources using a custom supplier with OIDC or SAML +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This library also allows for a custom implementation of :class:`google.auth.identity_pool.SubjectTokenSupplier` +to be specificed when creating a :class:`google.auth.identity_pool.Credential`. The supplier must +return a valid OIDC or SAML2.0 subject token, which will then be exchanged for a +Google Cloud access token. If an error occurs during token retrieval, the supplier +should return a :class:`google.auth.exceptions.RefreshError` and indicate via the error +whether the subject token retrieval is retryable. +Any call to the supplier from the Identity Pool credential will send a :class:`google.auth.external_account.SupplierContext` +object, which contains the requested audience and subject type. Additionally, the credential will +send the :class:`google.auth.transport.requests.Request` passed in the credential refresh call which +can be used to make HTTP requests.:: + + from google.auth import exceptions + from google.auth import identity_pool + + class CustomSubjectTokenSupplier(identity_pool.SubjectTokenSupplier): + + def get_subject_token(self, context, request): + audience = context.audience + subject_token_type = context.subject_token_type + try: + # Attempt to return the valid subject token of the requested type for the requested audience. + except Exception as e: + # If token retrieval fails, raise a refresh error, setting retryable to true if the client should + # attempt to retrieve the subject token again. + raise exceptions.RefreshError(e, retryable=True) + + + supplier = CustomSubjectTokenSupplier() + + credentials = identity_pool.Credentials( + AUDIENCE, # Set GCP Audience. + "urn:ietf:params:aws:token-type:jwt", # Set subject token type. + subject_token_supplier=supplier, # Set supplier. + scopes=SCOPES, # Set desired scopes. + workforce_pool_user_project=USER_PROJECT # Set workforce pool user project. + ) + +Where the audience is: ``//iam.googleapis.com/locations/global/workforcePools/$WORKFORCE_POOL_ID/providers/$PROVIDER_ID`` +Where the following variables need to be substituted: + +* ``$WORKFORCE_POOL_ID``: The workforce pool ID. +* ``$PROVIDER_ID``: The provider ID. + +and the workforce pool user project is the project number associated with the `workforce pools user project`_. + +The values for audience, service account impersonation URL, and any other builder field can also be found +by generating a `credential configuration file`_ with the gcloud CLI. + +.. _workforce pools user project: + https://cloud.google.com/iam/docs/workforce-identity-federation#workforce-pools-user-project +.. _credential configuration file: + https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#use_configuration_files_for_sign-in + Security considerations ~~~~~~~~~~~~~~~~~~~~~~~ @@ -814,7 +981,7 @@ Impersonated credentials ++++++++++++++++++++++++ Impersonated Credentials allows one set of credentials issued to a user or service account -to impersonate another. The source credentials must be granted +to impersonate another. The source credentials must be granted the "Service Account Token Creator" IAM role. :: from google.auth import impersonated_credentials @@ -884,7 +1051,7 @@ Token broker :: credential_access_boundary = downscoped.CredentialAccessBoundary( rules=[rule]) - # Retrieve the source credentials via ADC. + # Retrieve the source credentials via ADC. source_credentials, _ = google.auth.default() # Create the downscoped credentials. diff --git a/google/auth/aws.py b/google/auth/aws.py index 6e0e4e864..14ac8fc9a 100644 --- a/google/auth/aws.py +++ b/google/auth/aws.py @@ -21,10 +21,11 @@ AWS Credentials are initialized using external_account arguments which are typically loaded from the external credentials JSON file. -Unlike other Credentials that can be initialized with a list of explicit -arguments, secrets or credentials, external account clients use the -environment and hints/guidelines provided by the external_account JSON -file to retrieve credentials and exchange them for Google access tokens. + +This module also provides a definition for an abstract AWS security credentials supplier. +This supplier can be implemented to return valid AWS security credentials and an AWS region +and used to create AWS credentials. The credentials will then call the +supplier instead of using pre-defined methods such as calling the EC2 metadata endpoints. This module also provides a basic implementation of the `AWS Signature Version 4`_ request signing algorithm. @@ -37,6 +38,8 @@ .. _AWS STS GetCallerIdentity: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html """ +import abc +from dataclasses import dataclass import hashlib import hmac import http.client as http_client @@ -44,6 +47,7 @@ import os import posixpath import re +from typing import Optional import urllib from urllib.parse import urljoin @@ -61,6 +65,10 @@ _AWS_SECURITY_TOKEN_HEADER = "x-amz-security-token" # The AWS authorization header name for the auto-generated date. _AWS_DATE_HEADER = "x-amz-date" +# The default AWS regional credential verification URL. +_DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL = ( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" +) class RequestSigner(object): @@ -92,8 +100,7 @@ def get_request_options( https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html Args: - aws_security_credentials (Mapping[str, str]): A dictionary containing - the AWS security credentials. + aws_security_credentials (AWSSecurityCredentials): The AWS security credentials. url (str): The AWS service URL containing the canonical URI and query string. method (str): The HTTP method used to call this API. @@ -105,10 +112,6 @@ def get_request_options( Returns: Mapping[str, str]: The AWS signed request dictionary object. """ - # Get AWS credentials. - access_key = aws_security_credentials.get("access_key_id") - secret_key = aws_security_credentials.get("secret_access_key") - security_token = aws_security_credentials.get("security_token") additional_headers = additional_headers or {} @@ -129,9 +132,7 @@ def get_request_options( canonical_querystring=_get_canonical_querystring(uri.query), method=method, region=self._region_name, - access_key=access_key, - secret_key=secret_key, - security_token=security_token, + aws_security_credentials=aws_security_credentials, request_payload=request_payload, additional_headers=additional_headers, ) @@ -147,8 +148,8 @@ def get_request_options( headers[key] = additional_headers[key] # Add session token if available. - if security_token is not None: - headers[_AWS_SECURITY_TOKEN_HEADER] = security_token + if aws_security_credentials.session_token is not None: + headers[_AWS_SECURITY_TOKEN_HEADER] = aws_security_credentials.session_token signed_request = {"url": url, "method": method, "headers": headers} if request_payload: @@ -233,9 +234,7 @@ def _generate_authentication_header_map( canonical_querystring, method, region, - access_key, - secret_key, - security_token, + aws_security_credentials, request_payload="", additional_headers={}, ): @@ -248,10 +247,7 @@ def _generate_authentication_header_map( canonical_querystring (str): The AWS service URL query string. method (str): The HTTP method used to call this API. region (str): The AWS region. - access_key (str): The AWS access key ID. - secret_key (str): The AWS secret access key. - security_token (Optional[str]): The AWS security session token. This is - available for temporary sessions. + aws_security_credentials (AWSSecurityCredentials): The AWS security credentials. request_payload (Optional[str]): The optional request payload if available. additional_headers (Optional[Mapping[str, str]]): The optional @@ -274,8 +270,10 @@ def _generate_authentication_header_map( for key in additional_headers: full_headers[key.lower()] = additional_headers[key] # Add AWS session token if available. - if security_token is not None: - full_headers[_AWS_SECURITY_TOKEN_HEADER] = security_token + if aws_security_credentials.session_token is not None: + full_headers[ + _AWS_SECURITY_TOKEN_HEADER + ] = aws_security_credentials.session_token # Required headers full_headers["host"] = host @@ -321,14 +319,20 @@ def _generate_authentication_header_map( ) # https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html - signing_key = _get_signing_key(secret_key, date_stamp, region, service_name) + signing_key = _get_signing_key( + aws_security_credentials.secret_access_key, date_stamp, region, service_name + ) signature = hmac.new( signing_key, string_to_sign.encode("utf-8"), hashlib.sha256 ).hexdigest() # https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format( - _AWS_ALGORITHM, access_key, credential_scope, signed_headers, signature + _AWS_ALGORITHM, + aws_security_credentials.access_key_id, + credential_scope, + signed_headers, + signature, ) authentication_header = {"authorization_header": authorization_header} @@ -338,211 +342,112 @@ def _generate_authentication_header_map( return authentication_header -class Credentials(external_account.Credentials): - """AWS external account credentials. - This is used to exchange serialized AWS signature v4 signed requests to - AWS STS GetCallerIdentity service for Google access tokens. - """ - - def __init__( - self, - audience, - subject_token_type, - token_url, - credential_source=None, - *args, - **kwargs - ): - """Instantiates an AWS workload external account credentials object. - - Args: - audience (str): The STS audience field. - subject_token_type (str): The subject token type. - token_url (str): The STS endpoint URL. - credential_source (Mapping): The credential source dictionary used - to provide instructions on how to retrieve external credential - to be exchanged for Google access tokens. - args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. - kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. - - Raises: - google.auth.exceptions.RefreshError: If an error is encountered during - access token retrieval logic. - ValueError: For invalid parameters. +@dataclass +class AwsSecurityCredentials: + """A class that models AWS security credentials with an optional session token. - .. note:: Typically one of the helper constructors - :meth:`from_file` or - :meth:`from_info` are used instead of calling the constructor directly. - """ - super(Credentials, self).__init__( - audience=audience, - subject_token_type=subject_token_type, - token_url=token_url, - credential_source=credential_source, - *args, - **kwargs - ) - credential_source = credential_source or {} - self._environment_id = credential_source.get("environment_id") or "" - self._region_url = credential_source.get("region_url") - self._security_credentials_url = credential_source.get("url") - self._cred_verification_url = credential_source.get( - "regional_cred_verification_url" - ) - self._imdsv2_session_token_url = credential_source.get( - "imdsv2_session_token_url" - ) - self._region = None - self._request_signer = None - self._target_resource = audience + Attributes: + access_key_id (str): The AWS security credentials access key id. + secret_access_key (str): The AWS security credentials secret access key. + session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials. + """ - # Get the environment ID. Currently, only one version supported (v1). - matches = re.match(r"^(aws)([\d]+)$", self._environment_id) - if matches: - env_id, env_version = matches.groups() - else: - env_id, env_version = (None, None) + access_key_id: str + secret_access_key: str + session_token: Optional[str] = None - if env_id != "aws" or self._cred_verification_url is None: - raise exceptions.InvalidResource( - "No valid AWS 'credential_source' provided" - ) - elif int(env_version or "") != 1: - raise exceptions.InvalidValue( - "aws version '{}' is not supported in the current build.".format( - env_version - ) - ) - def retrieve_subject_token(self, request): - """Retrieves the subject token using the credential_source object. - The subject token is a serialized `AWS GetCallerIdentity signed request`_. +class AwsSecurityCredentialsSupplier(metaclass=abc.ABCMeta): + """Base class for AWS security credential suppliers. This can be implemented with custom logic to retrieve + AWS security credentials to exchange for a Google Cloud access token. The AWS external account credential does + not cache the AWS security credentials, so caching logic should be added in the implementation. + """ - The logic is summarized as: + @abc.abstractmethod + def get_aws_security_credentials(self, context, request): + """Returns the AWS security credentials for the requested context. - Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION - environment variable or from the AWS metadata server availability-zone - if not found in the environment variable. + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. - Check AWS credentials in environment variables. If not found, retrieve - from the AWS metadata server security-credentials endpoint. - - When retrieving AWS credentials from the metadata server - security-credentials endpoint, the AWS role needs to be determined by - calling the security-credentials endpoint without any argument. Then the - credentials can be retrieved via: security-credentials/role_name + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. - Generate the signed request to AWS STS GetCallerIdentity action. + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + security credential retrieval logic. - Inject x-goog-cloud-target-resource into header and serialize the - signed request. This will be the subject-token to pass to GCP STS. + Returns: + AwsSecurityCredentials: The requested AWS security credentials. + """ + raise NotImplementedError("") - .. _AWS GetCallerIdentity signed request: - https://cloud.google.com/iam/docs/access-resources-aws#exchange-token + @abc.abstractmethod + def get_aws_region(self, context, request): + """Returns the AWS region for the requested context. Args: - request (google.auth.transport.Request): A callable used to make + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make HTTP requests. - Returns: - str: The retrieved subject token. - """ - # Fetch the session token required to make meta data endpoint calls to aws. - if ( - request is not None - and self._imdsv2_session_token_url is not None - and self._should_use_metadata_server() - ): - headers = {"X-aws-ec2-metadata-token-ttl-seconds": "300"} - imdsv2_session_token_response = request( - url=self._imdsv2_session_token_url, method="PUT", headers=headers - ) + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + region retrieval logic. - if imdsv2_session_token_response.status != 200: - raise exceptions.RefreshError( - "Unable to retrieve AWS Session Token", - imdsv2_session_token_response.data, - ) + Returns: + str: The AWS region. + """ + raise NotImplementedError("") - imdsv2_session_token = imdsv2_session_token_response.data - else: - imdsv2_session_token = None - # Initialize the request signer if not yet initialized after determining - # the current AWS region. - if self._request_signer is None: - self._region = self._get_region( - request, self._region_url, imdsv2_session_token - ) - self._request_signer = RequestSigner(self._region) +class _DefaultAwsSecurityCredentialsSupplier(AwsSecurityCredentialsSupplier): + """Default implementation of AWS security credentials supplier. Supports retrieving + credentials and region via EC2 metadata endpoints and environment variables. + """ - # Retrieve the AWS security credentials needed to generate the signed - # request. - aws_security_credentials = self._get_security_credentials( - request, imdsv2_session_token - ) - # Generate the signed request to AWS STS GetCallerIdentity API. - # Use the required regional endpoint. Otherwise, the request will fail. - request_options = self._request_signer.get_request_options( - aws_security_credentials, - self._cred_verification_url.replace("{region}", self._region), - "POST", + def __init__(self, credential_source): + self._region_url = credential_source.get("region_url") + self._security_credentials_url = credential_source.get("url") + self._imdsv2_session_token_url = credential_source.get( + "imdsv2_session_token_url" ) - # The GCP STS endpoint expects the headers to be formatted as: - # [ - # {key: 'x-amz-date', value: '...'}, - # {key: 'Authorization', value: '...'}, - # ... - # ] - # And then serialized as: - # quote(json.dumps({ - # url: '...', - # method: 'POST', - # headers: [{key: 'x-amz-date', value: '...'}, ...] - # })) - request_headers = request_options.get("headers") - # The full, canonical resource name of the workload identity pool - # provider, with or without the HTTPS prefix. - # Including this header as part of the signature is recommended to - # ensure data integrity. - request_headers["x-goog-cloud-target-resource"] = self._target_resource - # Serialize AWS signed request. - # Keeping inner keys in sorted order makes testing easier for Python - # versions <=3.5 as the stringified JSON string would have a predictable - # key order. - aws_signed_req = {} - aws_signed_req["url"] = request_options.get("url") - aws_signed_req["method"] = request_options.get("method") - aws_signed_req["headers"] = [] - # Reformat header to GCP STS expected format. - for key in sorted(request_headers.keys()): - aws_signed_req["headers"].append( - {"key": key, "value": request_headers[key]} - ) + @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) + def get_aws_security_credentials(self, context, request): - return urllib.parse.quote( - json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) + # Check environment variables for permanent credentials first. + # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html + env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) + env_aws_secret_access_key = os.environ.get( + environment_vars.AWS_SECRET_ACCESS_KEY ) + # This is normally not available for permanent credentials. + env_aws_session_token = os.environ.get(environment_vars.AWS_SESSION_TOKEN) + if env_aws_access_key_id and env_aws_secret_access_key: + return AwsSecurityCredentials( + env_aws_access_key_id, env_aws_secret_access_key, env_aws_session_token + ) - def _get_region(self, request, url, imdsv2_session_token): - """Retrieves the current AWS region from either the AWS_REGION or - AWS_DEFAULT_REGION environment variable or from the AWS metadata server. + imdsv2_session_token = self._get_imdsv2_session_token(request) + role_name = self._get_metadata_role_name(request, imdsv2_session_token) - Args: - request (google.auth.transport.Request): A callable used to make - HTTP requests. - url (str): The AWS metadata server region URL. - imdsv2_session_token (str): The AWS IMDSv2 session token to be added as a - header in the requests to AWS metadata endpoint. + # Get security credentials. + credentials = self._get_metadata_security_credentials( + request, role_name, imdsv2_session_token + ) - Returns: - str: The current AWS region. + return AwsSecurityCredentials( + credentials.get("AccessKeyId"), + credentials.get("SecretAccessKey"), + credentials.get("Token"), + ) - Raises: - google.auth.exceptions.RefreshError: If an error occurs while - retrieving the AWS region. - """ + @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) + def get_aws_region(self, context, request): # The AWS metadata server is not available in some AWS environments # such as AWS lambda. Instead, it is available via environment # variable. @@ -558,6 +463,7 @@ def _get_region(self, request, url, imdsv2_session_token): raise exceptions.RefreshError("Unable to determine AWS region") headers = None + imdsv2_session_token = self._get_imdsv2_session_token(request) if imdsv2_session_token is not None: headers = {"X-aws-ec2-metadata-token": imdsv2_session_token} @@ -579,53 +485,23 @@ def _get_region(self, request, url, imdsv2_session_token): # Only the us-east-2 part should be used. return response_body[:-1] - def _get_security_credentials(self, request, imdsv2_session_token): - """Retrieves the AWS security credentials required for signing AWS - requests from either the AWS security credentials environment variables - or from the AWS metadata server. - - Args: - request (google.auth.transport.Request): A callable used to make - HTTP requests. - imdsv2_session_token (str): The AWS IMDSv2 session token to be added as a - header in the requests to AWS metadata endpoint. - - Returns: - Mapping[str, str]: The AWS security credentials dictionary object. - - Raises: - google.auth.exceptions.RefreshError: If an error occurs while - retrieving the AWS security credentials. - """ - - # Check environment variables for permanent credentials first. - # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html - env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) - env_aws_secret_access_key = os.environ.get( - environment_vars.AWS_SECRET_ACCESS_KEY - ) - # This is normally not available for permanent credentials. - env_aws_session_token = os.environ.get(environment_vars.AWS_SESSION_TOKEN) - if env_aws_access_key_id and env_aws_secret_access_key: - return { - "access_key_id": env_aws_access_key_id, - "secret_access_key": env_aws_secret_access_key, - "security_token": env_aws_session_token, - } + def _get_imdsv2_session_token(self, request): + if request is not None and self._imdsv2_session_token_url is not None: + headers = {"X-aws-ec2-metadata-token-ttl-seconds": "300"} - # Get role name. - role_name = self._get_metadata_role_name(request, imdsv2_session_token) + imdsv2_session_token_response = request( + url=self._imdsv2_session_token_url, method="PUT", headers=headers + ) - # Get security credentials. - credentials = self._get_metadata_security_credentials( - request, role_name, imdsv2_session_token - ) + if imdsv2_session_token_response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve AWS Session Token", + imdsv2_session_token_response.data, + ) - return { - "access_key_id": credentials.get("AccessKeyId"), - "secret_access_key": credentials.get("SecretAccessKey"), - "security_token": credentials.get("Token"), - } + return imdsv2_session_token_response.data + else: + return None def _get_metadata_security_credentials( self, request, role_name, imdsv2_session_token @@ -722,30 +598,230 @@ def _get_metadata_role_name(self, request, imdsv2_session_token): return response_body - def _should_use_metadata_server(self): - # The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. - # The metadata server should be used if it cannot be retrieved from one of - # these environment variables. - if not os.environ.get(environment_vars.AWS_REGION) and not os.environ.get( - environment_vars.AWS_DEFAULT_REGION - ): - return True - # AWS security credentials can be retrieved from the AWS_ACCESS_KEY_ID - # and AWS_SECRET_ACCESS_KEY environment variables. The metadata server - # should be used if either of these are not available. - if not os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) or not os.environ.get( - environment_vars.AWS_SECRET_ACCESS_KEY +class Credentials(external_account.Credentials): + """AWS external account credentials. + This is used to exchange serialized AWS signature v4 signed requests to + AWS STS GetCallerIdentity service for Google access tokens. + """ + + def __init__( + self, + audience, + subject_token_type, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + aws_security_credentials_supplier=None, + *args, + **kwargs + ): + """Instantiates an AWS workload external account credentials object. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:aws:token-type:aws4_request” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used + to provide instructions on how to retrieve external credential to be exchanged for Google access tokens. + Either a credential source or an AWS security credentials supplier must be provided. + + Example credential_source for AWS credential:: + + { + "environment_id": "aws1", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + imdsv2_session_token_url": "http://169.254.169.254/latest/api/token" + } + + aws_security_credentials_supplier (Optional [AwsSecurityCredentialsSupplier]): Optional AWS security credentials supplier. + This will be called to supply valid AWS security credentails which will then + be exchanged for Google access tokens. Either an AWS security credentials supplier + or a credential source must be provided. + args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + ValueError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + *args, + **kwargs + ) + if credential_source is None and aws_security_credentials_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or AWS security credentials supplier must be provided." + ) + if ( + credential_source is not None + and aws_security_credentials_supplier is not None ): - return True + raise exceptions.InvalidValue( + "AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) + + if aws_security_credentials_supplier: + self._aws_security_credentials_supplier = aws_security_credentials_supplier + # The regional cred verification URL would normally be provided through the credential source. So set it to the default one here. + self._cred_verification_url = ( + _DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL + ) + else: + environment_id = credential_source.get("environment_id") or "" + self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier( + credential_source + ) + self._cred_verification_url = credential_source.get( + "regional_cred_verification_url" + ) + + # Get the environment ID. Currently, only one version supported (v1). + matches = re.match(r"^(aws)([\d]+)$", environment_id) + if matches: + env_id, env_version = matches.groups() + else: + env_id, env_version = (None, None) + + if env_id != "aws" or self._cred_verification_url is None: + raise exceptions.InvalidResource( + "No valid AWS 'credential_source' provided" + ) + elif int(env_version or "") != 1: + raise exceptions.InvalidValue( + "aws version '{}' is not supported in the current build.".format( + env_version + ) + ) + + self._target_resource = audience + self._request_signer = None + + def retrieve_subject_token(self, request): + """Retrieves the subject token using the credential_source object. + The subject token is a serialized `AWS GetCallerIdentity signed request`_. + + The logic is summarized as: + + Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION + environment variable or from the AWS metadata server availability-zone + if not found in the environment variable. + + Check AWS credentials in environment variables. If not found, retrieve + from the AWS metadata server security-credentials endpoint. + + When retrieving AWS credentials from the metadata server + security-credentials endpoint, the AWS role needs to be determined by + calling the security-credentials endpoint without any argument. Then the + credentials can be retrieved via: security-credentials/role_name + + Generate the signed request to AWS STS GetCallerIdentity action. + + Inject x-goog-cloud-target-resource into header and serialize the + signed request. This will be the subject-token to pass to GCP STS. + + .. _AWS GetCallerIdentity signed request: + https://cloud.google.com/iam/docs/access-resources-aws#exchange-token - return False + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + str: The retrieved subject token. + """ + + # Initialize the request signer if not yet initialized after determining + # the current AWS region. + if self._request_signer is None: + self._region = self._aws_security_credentials_supplier.get_aws_region( + self._supplier_context, request + ) + self._request_signer = RequestSigner(self._region) + + # Retrieve the AWS security credentials needed to generate the signed + # request. + aws_security_credentials = self._aws_security_credentials_supplier.get_aws_security_credentials( + self._supplier_context, request + ) + # Generate the signed request to AWS STS GetCallerIdentity API. + # Use the required regional endpoint. Otherwise, the request will fail. + request_options = self._request_signer.get_request_options( + aws_security_credentials, + self._cred_verification_url.replace("{region}", self._region), + "POST", + ) + # The GCP STS endpoint expects the headers to be formatted as: + # [ + # {key: 'x-amz-date', value: '...'}, + # {key: 'Authorization', value: '...'}, + # ... + # ] + # And then serialized as: + # quote(json.dumps({ + # url: '...', + # method: 'POST', + # headers: [{key: 'x-amz-date', value: '...'}, ...] + # })) + request_headers = request_options.get("headers") + # The full, canonical resource name of the workload identity pool + # provider, with or without the HTTPS prefix. + # Including this header as part of the signature is recommended to + # ensure data integrity. + request_headers["x-goog-cloud-target-resource"] = self._target_resource + + # Serialize AWS signed request. + # Keeping inner keys in sorted order makes testing easier for Python + # versions <=3.5 as the stringified JSON string would have a predictable + # key order. + aws_signed_req = {} + aws_signed_req["url"] = request_options.get("url") + aws_signed_req["method"] = request_options.get("method") + aws_signed_req["headers"] = [] + # Reformat header to GCP STS expected format. + for key in sorted(request_headers.keys()): + aws_signed_req["headers"].append( + {"key": key, "value": request_headers[key]} + ) + + return urllib.parse.quote( + json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) + ) def _create_default_metrics_options(self): metrics_options = super(Credentials, self)._create_default_metrics_options() metrics_options["source"] = "aws" + if self._has_custom_supplier(): + metrics_options["source"] = "programmatic" return metrics_options + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update( + { + "aws_security_credentials_supplier": self._aws_security_credentials_supplier + } + ) + return args + @classmethod def from_info(cls, info, **kwargs): """Creates an AWS Credentials instance from parsed external account info. @@ -761,6 +837,12 @@ def from_info(cls, info, **kwargs): Raises: ValueError: For invalid parameters. """ + aws_security_credentials_supplier = info.get( + "aws_security_credentials_supplier" + ) + kwargs.update( + {"aws_security_credentials_supplier": aws_security_credentials_supplier} + ) return super(Credentials, cls).from_info(info, **kwargs) @classmethod diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 0420883f8..c14001bc2 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -29,6 +29,7 @@ import abc import copy +from dataclasses import dataclass import datetime import io import json @@ -50,6 +51,29 @@ _STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" # Cloud resource manager URL used to retrieve project information. _CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/" +# Default Google sts token url. +_DEFAULT_TOKEN_URL = "https://sts.googleapis.com/v1/token" + + +@dataclass +class SupplierContext: + """A context class that contains information about the requested third party credential that is passed + to AWS security credential and subject token suppliers. + + Attributes: + subject_token_type (str): The requested subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + + audience (str): The requested audience for the subject token. + """ + + subject_token_type: str + audience: str class Credentials( @@ -88,7 +112,14 @@ def __init__( Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + token_url (str): The STS endpoint URL. credential_source (Mapping): The credential source dictionary. service_account_impersonation_url (Optional[str]): The optional service account @@ -145,11 +176,11 @@ def __init__( self._metrics_options = self._create_default_metrics_options() - if self._service_account_impersonation_url: - self._impersonated_credentials = self._initialize_impersonated_credentials() - else: - self._impersonated_credentials = None + self._impersonated_credentials = None self._project_id = None + self._supplier_context = SupplierContext( + self._subject_token_type, self._audience + ) if not self.is_workforce_pool and self._workforce_pool_user_project: # Workload identity pools do not support workforce pool user projects. @@ -358,6 +389,10 @@ def get_project_id(self, request): @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): scopes = self._scopes if self._scopes is not None else self._default_scopes + + if self._should_initialize_impersonated_credentials(): + self._impersonated_credentials = self._initialize_impersonated_credentials() + if self._impersonated_credentials: self._impersonated_credentials.refresh(request) self.token = self._impersonated_credentials.token @@ -421,6 +456,12 @@ def with_universe_domain(self, universe_domain): new_cred._metrics_options = self._metrics_options return new_cred + def _should_initialize_impersonated_credentials(self): + return ( + self._service_account_impersonation_url is not None + and self._impersonated_credentials is None + ) + def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index a515353c3..5526e775c 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -26,11 +26,13 @@ Identity Pool Credentials are initialized using external_account arguments which are typically loaded from an external credentials file or -an external credentials URL. Unlike other Credentials that can be initialized -with a list of explicit arguments, secrets or credentials, external account -clients use the environment and hints/guidelines provided by the -external_account JSON file to retrieve credentials and exchange them for Google -access tokens. +an external credentials URL. + +This module also provides a definition for an abstract subject token supplier. +This supplier can be implemented to return a valid OIDC or SAML2.0 subject token +and used to create Identity Pool credentials. The credentials will then call the +supplier instead of using pre-defined methods such as reading a local file or +calling a URL. """ try: @@ -38,15 +40,130 @@ # Python 2.7 compatibility except ImportError: # pragma: NO COVER from collections import Mapping +import abc import io import json import os +from typing import NamedTuple from google.auth import _helpers from google.auth import exceptions from google.auth import external_account +class SubjectTokenSupplier(metaclass=abc.ABCMeta): + """Base class for subject token suppliers. This can be implemented with custom logic to retrieve + a subject token to exchange for a Google Cloud access token when using Workload or + Workforce Identity Federation. The identity pool credential does not cache the subject token, + so caching logic should be added in the implementation. + """ + + @abc.abstractmethod + def get_subject_token(self, context, request): + """Returns the requested subject token. The subject token must be valid. + + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. + + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + subject token retrieval logic. + + Returns: + str: The requested subject token string. + """ + raise NotImplementedError("") + + +class _TokenContent(NamedTuple): + """Models the token content response from file and url internal suppliers. + Attributes: + content (str): The string content of the file or URL response. + location (str): The location the content was retrieved from. This will either be a file location or a URL. + """ + + content: str + location: str + + +class _FileSupplier(SubjectTokenSupplier): + """ Internal implementation of subject token supplier which supports reading a subject token from a file.""" + + def __init__(self, path, format_type, subject_token_field_name): + self._path = path + self._format_type = format_type + self._subject_token_field_name = subject_token_field_name + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + if not os.path.exists(self._path): + raise exceptions.RefreshError("File '{}' was not found.".format(self._path)) + + with io.open(self._path, "r", encoding="utf-8") as file_obj: + token_content = _TokenContent(file_obj.read(), self._path) + + return _parse_token_data( + token_content, self._format_type, self._subject_token_field_name + ) + + +class _UrlSupplier(SubjectTokenSupplier): + """ Internal implementation of subject token supplier which supports retrieving a subject token by calling a URL endpoint.""" + + def __init__(self, url, format_type, subject_token_field_name, headers): + self._url = url + self._format_type = format_type + self._subject_token_field_name = subject_token_field_name + self._headers = headers + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + response = request(url=self._url, method="GET", headers=self._headers) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve Identity Pool subject token", response_body + ) + token_content = _TokenContent(response_body, self._url) + return _parse_token_data( + token_content, self._format_type, self._subject_token_field_name + ) + + +def _parse_token_data(token_content, format_type="text", subject_token_field_name=None): + if format_type == "text": + token = token_content.content + else: + try: + # Parse file content as JSON. + response_data = json.loads(token_content.content) + # Get the subject_token. + token = response_data[subject_token_field_name] + except (KeyError, ValueError): + raise exceptions.RefreshError( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + token_content.location, subject_token_field_name + ) + ) + if not token: + raise exceptions.RefreshError( + "Missing subject_token in the credential_source file" + ) + return token + + class Credentials(external_account.Credentials): """External account credentials sourced from files and URLs.""" @@ -54,8 +171,9 @@ def __init__( self, audience, subject_token_type, - token_url, - credential_source, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + subject_token_supplier=None, *args, **kwargs ): @@ -63,11 +181,18 @@ def __init__( Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. - token_url (str): The STS endpoint URL. - credential_source (Mapping): The credential source dictionary used to + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used to provide instructions on how to retrieve external credential to be - exchanged for Google access tokens. + exchanged for Google access tokens. Either a credential source or + a subject token supplier must be provided. Example credential_source for url-sourced credential:: @@ -85,6 +210,10 @@ def __init__( { "file": "/path/to/token/file.txt" } + subject_token_supplier (Optional [SubjectTokenSupplier]): Optional subject token supplier. + This will be called to supply a valid subject token which will then + be exchanged for Google access tokens. Either a subject token supplier + or a credential source must be provided. args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. @@ -106,10 +235,25 @@ def __init__( *args, **kwargs ) - if not isinstance(credential_source, Mapping): + if credential_source is None and subject_token_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or a subject token supplier must be provided." + ) + if credential_source is not None and subject_token_supplier is not None: + raise exceptions.InvalidValue( + "Identity pool credential cannot have both a credential source and a subject token supplier." + ) + + if subject_token_supplier is not None: + self._subject_token_supplier = subject_token_supplier self._credential_source_file = None self._credential_source_url = None else: + if not isinstance(credential_source, Mapping): + self._credential_source_executable = None + raise exceptions.MalformedError( + "Invalid credential_source. The credential_source is not a dict." + ) self._credential_source_file = credential_source.get("file") self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") @@ -143,79 +287,35 @@ def __init__( else: self._credential_source_field_name = None - if self._credential_source_file and self._credential_source_url: - raise exceptions.MalformedError( - "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." - ) - if not self._credential_source_file and not self._credential_source_url: - raise exceptions.MalformedError( - "Missing credential_source. A 'file' or 'url' must be provided." - ) + if self._credential_source_file and self._credential_source_url: + raise exceptions.MalformedError( + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." + ) + if not self._credential_source_file and not self._credential_source_url: + raise exceptions.MalformedError( + "Missing credential_source. A 'file' or 'url' must be provided." + ) + + if self._credential_source_file: + self._subject_token_supplier = _FileSupplier( + self._credential_source_file, + self._credential_source_format_type, + self._credential_source_field_name, + ) + else: + self._subject_token_supplier = _UrlSupplier( + self._credential_source_url, + self._credential_source_format_type, + self._credential_source_field_name, + self._credential_source_headers, + ) @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): - return self._parse_token_data( - self._get_token_data(request), - self._credential_source_format_type, - self._credential_source_field_name, - ) - - def _get_token_data(self, request): - if self._credential_source_file: - return self._get_file_data(self._credential_source_file) - else: - return self._get_url_data( - request, self._credential_source_url, self._credential_source_headers - ) - - def _get_file_data(self, filename): - if not os.path.exists(filename): - raise exceptions.RefreshError("File '{}' was not found.".format(filename)) - - with io.open(filename, "r", encoding="utf-8") as file_obj: - return file_obj.read(), filename - - def _get_url_data(self, request, url, headers): - response = request(url=url, method="GET", headers=headers) - - # support both string and bytes type response.data - response_body = ( - response.data.decode("utf-8") - if hasattr(response.data, "decode") - else response.data + return self._subject_token_supplier.get_subject_token( + self._supplier_context, request ) - if response.status != 200: - raise exceptions.RefreshError( - "Unable to retrieve Identity Pool subject token", response_body - ) - - return response_body, url - - def _parse_token_data( - self, token_content, format_type="text", subject_token_field_name=None - ): - content, filename = token_content - if format_type == "text": - token = content - else: - try: - # Parse file content as JSON. - response_data = json.loads(content) - # Get the subject_token. - token = response_data[subject_token_field_name] - except (KeyError, ValueError): - raise exceptions.RefreshError( - "Unable to parse subject_token from JSON file '{}' using key '{}'".format( - filename, subject_token_field_name - ) - ) - if not token: - raise exceptions.RefreshError( - "Missing subject_token in the credential_source file" - ) - return token - def _create_default_metrics_options(self): metrics_options = super(Credentials, self)._create_default_metrics_options() # Check that credential source is a dict before checking for file vs url. This check needs to be done @@ -226,8 +326,20 @@ def _create_default_metrics_options(self): metrics_options["source"] = "file" else: metrics_options["source"] = "url" + else: + metrics_options["source"] = "programmatic" return metrics_options + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update({"subject_token_supplier": self._subject_token_supplier}) + return args + @classmethod def from_info(cls, info, **kwargs): """Creates an Identity Pool Credentials instance from parsed external account info. @@ -244,6 +356,8 @@ def from_info(cls, info, **kwargs): Raises: ValueError: For invalid parameters. """ + subject_token_supplier = info.get("subject_token_supplier") + kwargs.update({"subject_token_supplier": subject_token_supplier}) return super(Credentials, cls).from_info(info, **kwargs) @classmethod diff --git a/tests/test_aws.py b/tests/test_aws.py index 3f358d52b..561482031 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -21,7 +21,7 @@ import mock import pytest # type: ignore -from google.auth import _helpers +from google.auth import _helpers, external_account from google.auth import aws from google.auth import environment_vars from google.auth import exceptions @@ -616,8 +616,13 @@ def test_get_request_options( ): utcnow.return_value = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%SZ") request_signer = aws.RequestSigner(region) + credentials_object = aws.AwsSecurityCredentials( + credentials.get("access_key_id"), + credentials.get("secret_access_key"), + credentials.get("security_token"), + ) actual_signed_request = request_signer.get_request_options( - credentials, + credentials_object, original_request.get("url"), original_request.get("method"), original_request.get("data"), @@ -631,10 +636,7 @@ def test_get_request_options_with_missing_scheme_url(self): with pytest.raises(ValueError) as excinfo: request_signer.get_request_options( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - }, + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY), "invalid", "POST", ) @@ -646,10 +648,7 @@ def test_get_request_options_with_invalid_scheme_url(self): with pytest.raises(ValueError) as excinfo: request_signer.get_request_options( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - }, + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY), "http://invalid", "POST", ) @@ -661,10 +660,7 @@ def test_get_request_options_with_missing_hostname_url(self): with pytest.raises(ValueError) as excinfo: request_signer.get_request_options( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - }, + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY), "https://", "POST", ) @@ -672,6 +668,36 @@ def test_get_request_options_with_missing_hostname_url(self): assert excinfo.match(r"Invalid AWS service URL") +class TestAwsSecurityCredentialsSupplier(aws.AwsSecurityCredentialsSupplier): + def __init__( + self, + security_credentials=None, + region=None, + credentials_exception=None, + region_exception=None, + expected_context=None, + ): + self._security_credentials = security_credentials + self._region = region + self._credentials_exception = credentials_exception + self._region_exception = region_exception + self._expected_context = expected_context + + def get_aws_security_credentials(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._credentials_exception is not None: + raise self._credentials_exception + return self._security_credentials + + def get_aws_region(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._region_exception is not None: + raise self._region_exception + return self._region + + class TestCredentials(object): AWS_REGION = "us-east-2" AWS_ROLE = "gcp-aws-role" @@ -734,7 +760,7 @@ def make_serialized_aws_signed_request( ], } # Include security token if available. - if "security_token" in aws_security_credentials: + if aws_security_credentials.session_token is not None: reformatted_signed_request.get("headers").append( { "key": "x-amz-security-token", @@ -773,16 +799,17 @@ def make_mock_request( in an AWS environment. """ responses = [] - if imdsv2_session_token_status: - # AWS session token request - imdsv2_session_response = mock.create_autospec( - transport.Response, instance=True - ) - imdsv2_session_response.status = imdsv2_session_token_status - imdsv2_session_response.data = imdsv2_session_token_data - responses.append(imdsv2_session_response) if region_status: + if imdsv2_session_token_status: + # AWS session token request + imdsv2_session_response = mock.create_autospec( + transport.Response, instance=True + ) + imdsv2_session_response.status = imdsv2_session_token_status + imdsv2_session_response.data = imdsv2_session_token_data + responses.append(imdsv2_session_response) + # AWS region request. region_response = mock.create_autospec(transport.Response, instance=True) region_response.status = region_status @@ -790,6 +817,15 @@ def make_mock_request( region_response.data = "{}b".format(region_name).encode("utf-8") responses.append(region_response) + if imdsv2_session_token_status: + # AWS session token request + imdsv2_session_response = mock.create_autospec( + transport.Response, instance=True + ) + imdsv2_session_response.status = imdsv2_session_token_status + imdsv2_session_response.data = imdsv2_session_token_data + responses.append(imdsv2_session_response) + if role_status: # AWS role name request. role_response = mock.create_autospec(transport.Response, instance=True) @@ -834,7 +870,8 @@ def make_mock_request( @classmethod def make_credentials( cls, - credential_source, + credential_source=None, + aws_security_credentials_supplier=None, token_url=TOKEN_URL, token_info_url=TOKEN_INFO_URL, client_id=None, @@ -851,6 +888,7 @@ def make_credentials( token_info_url=token_info_url, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, + aws_security_credentials_supplier=aws_security_credentials_supplier, client_id=client_id, client_secret=client_secret, quota_project_id=quota_project_id, @@ -929,6 +967,7 @@ def test_from_info_full_options(self, mock_init): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -957,6 +996,38 @@ def test_from_info_required_options_only(self, mock_init): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, + quota_project_id=None, + workforce_pool_user_project=None, + universe_domain=DEFAULT_UNIVERSE_DOMAIN, + ) + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_info_supplier(self, mock_init): + supplier = TestAwsSecurityCredentialsSupplier() + + credentials = aws.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "aws_security_credentials_supplier": supplier, + } + ) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + token_info_url=None, + service_account_impersonation_url=None, + service_account_impersonation_options={}, + client_id=None, + client_secret=None, + credential_source=None, + aws_security_credentials_supplier=supplier, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -993,6 +1064,7 @@ def test_from_file_full_options(self, mock_init, tmpdir): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -1022,6 +1094,7 @@ def test_from_file_required_options_only(self, mock_init, tmpdir): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -1036,6 +1109,27 @@ def test_constructor_invalid_credential_source(self): assert excinfo.match(r"No valid AWS 'credential_source' provided") + def test_constructor_invalid_credential_source_and_supplier(self): + # Provide both a credential source and supplier. + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier="test", + ) + + assert excinfo.match( + r"AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) + + def test_constructor_invalid_no_credential_source_or_supplier(self): + # Provide no credential source or supplier. + with pytest.raises(ValueError) as excinfo: + self.make_credentials() + + assert excinfo.match( + r"A valid credential source or AWS security credentials supplier must be provided." + ) + def test_constructor_invalid_environment_id(self): # Provide invalid environment_id. credential_source = self.CREDENTIAL_SOURCE.copy() @@ -1158,11 +1252,7 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars( subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert region request. self.assert_aws_metadata_request_kwargs( @@ -1231,11 +1321,7 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars_idmsv2( subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request self.assert_aws_metadata_request_kwargs( @@ -1250,15 +1336,22 @@ def test_retrieve_subject_token_success_temp_creds_no_environment_vars_idmsv2( REGION_URL, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) - # Assert role request. + # Assert session token request self.assert_aws_metadata_request_kwargs( request.call_args_list[2][1], + IMDSV2_SESSION_TOKEN_URL, + {"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + "PUT", + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[3][1], SECURITY_CREDS_URL, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) # Assert security credentials request. self.assert_aws_metadata_request_kwargs( - request.call_args_list[3][1], + request.call_args_list[4][1], "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE), { "Content-Type": "application/json", @@ -1335,11 +1428,7 @@ def test_retrieve_subject_token_success_temp_creds_environment_vars_missing_secr subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1396,11 +1485,7 @@ def test_retrieve_subject_token_success_temp_creds_environment_vars_missing_acce subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1451,11 +1536,7 @@ def test_retrieve_subject_token_success_temp_creds_environment_vars_missing_cred subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1530,11 +1611,7 @@ def test_retrieve_subject_token_success_ipv6(self, utcnow): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1549,15 +1626,22 @@ def test_retrieve_subject_token_success_ipv6(self, utcnow): REGION_URL_IPV6, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) - # Assert role request. + # Assert session token request. self.assert_aws_metadata_request_kwargs( request.call_args_list[2][1], + IMDSV2_SESSION_TOKEN_URL_IPV6, + {"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + "PUT", + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[3][1], SECURITY_CREDS_URL_IPV6, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) # Assert security credentials request. self.assert_aws_metadata_request_kwargs( - request.call_args_list[3][1], + request.call_args_list[4][1], "{}/{}".format(SECURITY_CREDS_URL_IPV6, self.AWS_ROLE), { "Content-Type": "application/json", @@ -1619,7 +1703,7 @@ def test_retrieve_subject_token_success_permanent_creds_no_environment_vars( subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY} + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) ) @mock.patch("google.auth._helpers.utcnow") @@ -1636,11 +1720,7 @@ def test_retrieve_subject_token_success_environment_vars(self, utcnow, monkeypat subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) @mock.patch("google.auth._helpers.utcnow") @@ -1659,11 +1739,7 @@ def test_retrieve_subject_token_success_environment_vars_with_default_region( subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) @mock.patch("google.auth._helpers.utcnow") @@ -1686,11 +1762,7 @@ def test_retrieve_subject_token_success_environment_vars_with_both_regions_set( subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) @mock.patch("google.auth._helpers.utcnow") @@ -1708,7 +1780,7 @@ def test_retrieve_subject_token_success_environment_vars_no_session_token( subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY} + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) ) @mock.patch("google.auth._helpers.utcnow") @@ -1730,11 +1802,7 @@ def test_retrieve_subject_token_success_environment_vars_except_region( subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) def test_retrieve_subject_token_error_determining_aws_region(self): @@ -1806,11 +1874,7 @@ def test_refresh_success_without_impersonation_ignore_default_scopes( self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" ) expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -1869,11 +1933,7 @@ def test_refresh_success_without_impersonation_use_default_scopes( self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" ) expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -1939,11 +1999,7 @@ def test_refresh_success_with_impersonation_ignore_default_scopes( _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) ).isoformat("T") + "Z" expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -2036,11 +2092,7 @@ def test_refresh_success_with_impersonation_use_default_scopes( _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) ).isoformat("T") + "Z" expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -2122,3 +2174,249 @@ def test_refresh_with_retrieve_subject_token_error(self): credentials.refresh(request) assert excinfo.match(r"Unable to retrieve AWS region") + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, region=self.AWS_REGION + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(request) + assert subject_token == self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier_session_token(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, region=self.AWS_REGION + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(request) + assert subject_token == self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier_correct_context(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + expected_context = external_account.SupplierContext( + SUBJECT_TOKEN_TYPE, AUDIENCE + ) + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, + region=self.AWS_REGION, + expected_context=expected_context, + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + credentials.retrieve_subject_token(request) + + def test_retrieve_subject_token_error_with_supplier(self): + request = self.make_mock_request() + expected_exception = exceptions.RefreshError("Test error") + supplier = TestAwsSecurityCredentialsSupplier( + region=self.AWS_REGION, credentials_exception=expected_exception + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Test error") + + def test_retrieve_subject_token_error_with_supplier_region(self): + request = self.make_mock_request() + expected_exception = exceptions.RefreshError("Test error") + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, + region_exception=expected_exception, + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Test error") + + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_supplier_with_impersonation( + self, utcnow, mock_auth_lib_value + ): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + expected_subject_token = self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/true config-lifetime/false source/programmatic", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "https://www.googleapis.com/auth/iam", + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + "x-goog-user-project": QUOTA_PROJECT_ID, + "x-goog-api-client": IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, + "x-allowed-locations": "0x0", + } + impersonation_request_data = { + "delegates": None, + "scope": SCOPES, + "lifetime": "3600s", + } + request = self.make_mock_request( + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ), + region=self.AWS_REGION, + ) + + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + aws_security_credentials_supplier=supplier, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 2 + # First request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # Second request should be sent to iamcredentials endpoint for service + # account impersonation. + self.assert_impersonation_request_kwargs( + request.call_args_list[1][1], + impersonation_headers, + impersonation_request_data, + ) + assert credentials.token == impersonation_response["accessToken"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] + + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_supplier(self, utcnow, mock_auth_lib_value): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expected_subject_token = self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/false config-lifetime/false source/programmatic", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": " ".join(SCOPES), + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + token_status=http_client.OK, token_data=self.SUCCESS_RESPONSE + ) + + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ), + region=self.AWS_REGION, + ) + + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + aws_security_credentials_supplier=supplier, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 1 + # First request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 03a5014ce..c458b21b6 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -477,16 +477,6 @@ def test_with_quota_project_full_options_propagated(self): universe_domain=DEFAULT_UNIVERSE_DOMAIN, ) - def test_with_invalid_impersonation_target_principal(self): - invalid_url = "https://iamcredentials.googleapis.com/v1/invalid" - - with pytest.raises(exceptions.RefreshError) as excinfo: - self.make_credentials(service_account_impersonation_url=invalid_url) - - assert excinfo.match( - r"Unable to determine target principal from service account impersonation URL." - ) - def test_info(self): credentials = self.make_credentials(universe_domain="dummy_universe.com") @@ -1069,6 +1059,21 @@ def test_refresh_impersonation_without_client_auth_error(self): assert not credentials.expired assert credentials.token is None + def test_refresh_impersonation_invalid_impersonated_url_error(self): + credentials = self.make_credentials( + service_account_impersonation_url="https://iamcredentials.googleapis.com/v1/invalid", + scopes=self.SCOPES, + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(None) + + assert excinfo.match( + r"Unable to determine target principal from service account impersonation URL." + ) + assert not credentials.expired + assert credentials.token is None + @mock.patch( "google.auth.metrics.python_and_auth_lib_version", return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, @@ -1913,3 +1918,10 @@ def test_get_project_id_cloud_resource_manager_error(self): assert project_id is None # Only 2 requests to STS and cloud resource manager should be sent. assert len(request.call_args_list) == 2 + + +def test_supplier_context(): + context = external_account.SupplierContext("TestTokenType", "TestAudience") + + assert context.subject_token_type == "TestTokenType" + assert context.audience == "TestAudience" diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index be30c4e9b..ac1d9a0bb 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -21,7 +21,7 @@ import mock import pytest # type: ignore -from google.auth import _helpers +from google.auth import _helpers, external_account from google.auth import exceptions from google.auth import identity_pool from google.auth import metrics @@ -151,6 +151,22 @@ ] +class TestSubjectTokenSupplier(identity_pool.SubjectTokenSupplier): + def __init__( + self, subject_token=None, subject_token_exception=None, expected_context=None + ): + self._subject_token = subject_token + self._subject_token_exception = subject_token_exception + self._expected_context = expected_context + + def get_subject_token(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._subject_token_exception is not None: + raise self._subject_token_exception + return self._subject_token + + class TestCredentials(object): CREDENTIAL_SOURCE_TEXT = {"file": SUBJECT_TOKEN_TEXT_FILE} CREDENTIAL_SOURCE_JSON = { @@ -273,10 +289,13 @@ def assert_underlying_credentials_refresh( else: metrics_options["sa-impersonation"] = "false" metrics_options["config-lifetime"] = "false" - if credentials._credential_source_file: - metrics_options["source"] = "file" + if credentials._credential_source: + if credentials._credential_source_file: + metrics_options["source"] = "file" + else: + metrics_options["source"] = "url" else: - metrics_options["source"] = "url" + metrics_options["source"] = "programmatic" token_headers["x-goog-api-client"] = metrics.byoid_metrics_header( metrics_options @@ -386,6 +405,7 @@ def make_credentials( default_scopes=None, service_account_impersonation_url=None, credential_source=None, + subject_token_supplier=None, workforce_pool_user_project=None, ): return identity_pool.Credentials( @@ -395,6 +415,7 @@ def make_credentials( token_info_url=token_info_url, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, + subject_token_supplier=subject_token_supplier, client_id=client_id, client_secret=client_secret, quota_project_id=quota_project_id, @@ -432,6 +453,7 @@ def test_from_info_full_options(self, mock_init): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -460,6 +482,38 @@ def test_from_info_required_options_only(self, mock_init): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, + quota_project_id=None, + workforce_pool_user_project=None, + universe_domain=DEFAULT_UNIVERSE_DOMAIN, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_supplier(self, mock_init): + supplier = TestSubjectTokenSupplier() + + credentials = identity_pool.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "subject_token_supplier": supplier, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + token_info_url=None, + service_account_impersonation_url=None, + service_account_impersonation_options={}, + client_id=None, + client_secret=None, + credential_source=None, + subject_token_supplier=supplier, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -489,6 +543,7 @@ def test_from_info_workforce_pool(self, mock_init): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -524,6 +579,7 @@ def test_from_file_full_options(self, mock_init, tmpdir): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -553,6 +609,7 @@ def test_from_file_required_options_only(self, mock_init, tmpdir): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -583,6 +640,7 @@ def test_from_file_workforce_pool(self, mock_init, tmpdir): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -633,7 +691,29 @@ def test_constructor_invalid_credential_source(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source="non-dict") - assert excinfo.match(r"Missing credential_source") + assert excinfo.match( + r"Invalid credential_source. The credential_source is not a dict." + ) + + def test_constructor_invalid_no_credential_source_or_supplier(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials() + + assert excinfo.match( + r"A valid credential source or a subject token supplier must be provided." + ) + + def test_constructor_invalid_both_credential_source_and_supplier(self): + supplier = TestSubjectTokenSupplier() + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=supplier, + ) + + assert excinfo.match( + r"Identity pool credential cannot have both a credential source and a subject token supplier." + ) def test_constructor_invalid_credential_source_format_type(self): credential_source = {"format": {"type": "xml"}} @@ -1297,3 +1377,78 @@ def test_refresh_with_retrieve_subject_token_error_url(self): self.CREDENTIAL_URL, "not_found" ) ) + + def test_retrieve_subject_token_supplier(self): + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + + def test_retrieve_subject_token_supplier_correct_context(self): + supplier = TestSubjectTokenSupplier( + subject_token=JSON_FILE_SUBJECT_TOKEN, + expected_context=external_account.SupplierContext( + SUBJECT_TOKEN_TYPE, AUDIENCE + ), + ) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + credentials.retrieve_subject_token(None) + + def test_retrieve_subject_token_supplier_error(self): + expected_exception = exceptions.RefreshError("test error") + supplier = TestSubjectTokenSupplier(subject_token_exception=expected_exception) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match("test error") + + def test_refresh_success_supplier_with_impersonation_url(self): + # Initialize credentials with service account impersonation and a supplier. + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + credentials = self.make_credentials( + subject_token_supplier=supplier, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + ) + + def test_refresh_success_supplier_without_impersonation_url(self): + # Initialize supplier credentials without service account impersonation. + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + credentials = self.make_credentials( + subject_token_supplier=supplier, scopes=SCOPES + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + )