From ef6d872628974270aea2c2d62d408d3a301335d6 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Mon, 27 Jul 2020 20:53:13 -0500 Subject: [PATCH 01/20] feat: asyncio http request logic and asynchronous credentials log c --- google/auth/credentials_async.py | 168 ++++++++++++ google/auth/transport/aiohttp_req.py | 297 ++++++++++++++++++++++ noxfile.py | 54 +++- tests_async/test_credentials.py | 183 +++++++++++++ tests_async/transport/__init__.py | 0 tests_async/transport/async_compliance.py | 136 ++++++++++ tests_async/transport/test_aiohttp_req.py | 163 ++++++++++++ 7 files changed, 998 insertions(+), 3 deletions(-) create mode 100644 google/auth/credentials_async.py create mode 100644 google/auth/transport/aiohttp_req.py create mode 100644 tests_async/test_credentials.py create mode 100644 tests_async/transport/__init__.py create mode 100644 tests_async/transport/async_compliance.py create mode 100644 tests_async/transport/test_aiohttp_req.py diff --git a/google/auth/credentials_async.py b/google/auth/credentials_async.py new file mode 100644 index 000000000..a131cc44b --- /dev/null +++ b/google/auth/credentials_async.py @@ -0,0 +1,168 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces for credentials.""" + +import abc + +import six + +from google.auth import credentials + + +@six.add_metaclass(abc.ABCMeta) +class Credentials(credentials.Credentials): + """Async inherited credentials class from google.auth.credentials. + The added functionality is the before_request call which requires + async/await syntax. + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + """ + + async def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the authentication header. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (Subclasses may use these arguments to ascertain information about + # the http request.) + + if not self.valid: + self.refresh(request) + self.apply(headers) + + +class AnonymousCredentials(credentials.AnonymousCredentials, Credentials): + """Credentials that do not provide any authentication information. + + These are useful in the case of services that support anonymous access or + local service emulators that do not use credentials. This class inherits + from the sync anonymous credentials file, but is kept if async credentials + is initialized and we would like anonymous credentials. + """ + + +@six.add_metaclass(abc.ABCMeta) +class ReadOnlyScoped(credentials.ReadOnlyScoped): + """Interface for credentials whose scopes can be queried. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = credentials_async.with_scopes(scopes=['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials_async.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +class Scoped(credentials.Scoped): + """Interface for credentials whose scopes can be replaced while copying. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = credentials_async.create_scoped(['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +def with_scopes_if_required(credentials, scopes): + """Creates a copy of the credentials with scopes if scoping is required. + + This helper function is useful when you do not know (or care to know) the + specific type of credentials you are using (such as when you use + :func:`google.auth.default`). This function will call + :meth:`Scoped.with_scopes` if the credentials are scoped credentials and if + the credentials require scoping. Otherwise, it will return the credentials + as-is. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + scope if necessary. + scopes (Sequence[str]): The list of scopes to use. + + Returns: + google.auth.credentials_async.Credentials: Either a new set of scoped + credentials, or the passed in credentials instance if no scoping + was required. + """ + if isinstance(credentials, Scoped) and credentials.requires_scopes: + return credentials.with_scopes(scopes) + else: + return credentials + + +@six.add_metaclass(abc.ABCMeta) +class Signing(credentials.Signing): + """Interface for credentials that can cryptographically sign messages.""" diff --git a/google/auth/transport/aiohttp_req.py b/google/auth/transport/aiohttp_req.py new file mode 100644 index 000000000..cf3f7abe1 --- /dev/null +++ b/google/auth/transport/aiohttp_req.py @@ -0,0 +1,297 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport adapter for Async HTTP (aiohttp).""" + +from __future__ import absolute_import + +import asyncio +import functools +import logging + + +import aiohttp +import six + + +from google.auth import exceptions +from google.auth import transport +from google.auth.transport import requests + + +_OAUTH_SCOPES = [ + "https://www.googleapis.com/auth/appengine.apis", + "https://www.googleapis.com/auth/userinfo.email", +] + +_LOGGER = logging.getLogger(__name__) + +# Timeout can be re-defined depending on async requirement. Currently made 60s more than +# sync timeout. +_DEFAULT_TIMEOUT = 180 # in seconds + + +class _Response(transport.Response): + """ + Requests transport response adapter. + + Args: + response (requests.Response): The raw Requests response. + """ + + def __init__(self, response): + self.response = response + + @property + def status(self): + return self.response.status + + @property + def headers(self): + return self.response.headers + + @property + def data(self): + return self.response.content + + +class Request(transport.Request): + """Requests request adapter. + + This class is used internally for making requests using asyncio transports + in a consistent way. If you use :class:`AuthorizedSession` you do not need + to construct or use this class directly. + + This class can be useful if you want to manually refresh a + :class:`~google.auth.credentials.Credentials` instance:: + + import google.auth.transport.aiohttp_req + import aiohttp + + request = google.auth.transport.aiohttp_req.Request() + + credentials.refresh(request) + + Args: + session (aiohttp.ClientSession): An instance :class: aiohttp.ClientSession used + to make HTTP requests. If not specified, a session will be created. + + .. automethod:: __call__ + """ + + def __init__(self, session=None): + """ + self.session = None + if not session: + session = aiohttp.ClientSession() + """ + self.session = None + + async def __call__( + self, + url, + method="GET", + body=None, + headers=None, + timeout=_DEFAULT_TIMEOUT, + **kwargs + ): + """ + Make an HTTP request using aiohttp. + + Args: + url (str): The URI to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (bytes): The payload / body in HTTP request. + headers (Mapping[str, str]): Request headers. + timeout (Optional[int]): The number of seconds to wait for a + response from the server. If not specified or if None, the + requests default timeout will be used. + kwargs: Additional arguments passed through to the underlying + requests :meth:`~requests.Session.request` method. + + Returns: + google.auth.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + + try: + if self.session is None: # pragma: NO COVER + self.session = aiohttp.ClientSession() # pragma: NO COVER + _LOGGER.debug("Making request: %s %s", method, url) + response = await self.session.request( + method, url, data=body, headers=headers, timeout=timeout, **kwargs + ) + return _Response(response) + + except aiohttp.ClientError as caught_exc: + new_exc = exceptions.TransportError(caught_exc) + six.raise_from(new_exc, caught_exc) + + except asyncio.TimeoutError as caught_exc: + new_exc = exceptions.TransportError(caught_exc) + six.raise_from(new_exc, caught_exc) + + +class AuthorizedSession(aiohttp.ClientSession): + """This is an async implementation of the Authorized Session class. We utilize an + aiohttp transport instance, and the interface mirrors the google.auth.transport.requests + Authorized Session class, except for the change in the transport used in the async use case. + + A Requests Session class with credentials. + + This class is used to perform requests to API endpoints that require + authorization:: + + import google.auth.transport.aiohttp_req + + async with aiohttp_req.AuthorizedSession(credentials) as authed_session: + response = await authed_session.request( + 'GET', 'https://www.googleapis.com/storage/v1/b') + + The underlying :meth:`request` implementation handles adding the + credentials' headers to the request and refreshing credentials as needed. + + Args: + credentials (google.auth.credentials_async.Credentials): The credentials to + add to the request. + refresh_status_codes (Sequence[int]): Which HTTP status codes indicate + that credentials should be refreshed and the request should be + retried. + max_refresh_attempts (int): The maximum number of times to attempt to + refresh the credentials and retry the request. + refresh_timeout (Optional[int]): The timeout value in seconds for + credential refresh HTTP requests. + auth_request (google.auth.transport.aiohttp_req.Request): + (Optional) An instance of + :class:`~google.auth.transport.aiohttp_req.Request` used when + refreshing credentials. If not passed, + an instance of :class:`~google.auth.transport.aiohttp_req.Request` + is created. + """ + + def __init__( + self, + credentials, + refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, + max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, + refresh_timeout=None, + auth_request=None, + ): + super(AuthorizedSession, self).__init__() + self.credentials = credentials + self._refresh_status_codes = refresh_status_codes + self._max_refresh_attempts = max_refresh_attempts + self._refresh_timeout = refresh_timeout + self._is_mtls = False + self._auth_request = auth_request + self._auth_request_session = None + self._loop = asyncio.get_event_loop() + self._refresh_lock = asyncio.Lock() + + async def request( + self, + method, + url, + data=None, + headers=None, + max_allowed_time=None, + timeout=_DEFAULT_TIMEOUT, + **kwargs + ): + + if self._auth_request is None: + self._auth_request_session = aiohttp.ClientSession() + auth_request = Request(self._auth_request_session) + self._auth_request = auth_request + + # Use a kwarg for this instead of an attribute to maintain + # thread-safety. + _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) + # Make a copy of the headers. They will be modified by the credentials + # and we want to pass the original headers if we recurse. + request_headers = headers.copy() if headers is not None else {} + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + remaining_time = max_allowed_time + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + await self.credentials.before_request( + auth_request, method, url, request_headers + ) + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + response = await super(AuthorizedSession, self).request( + method, + url, + data=data, + headers=request_headers, + timeout=timeout, + **kwargs + ) + + remaining_time = guard.remaining_timeout + + if ( + response.status in self._refresh_status_codes + and _credential_refresh_attempt < self._max_refresh_attempts + ): + + _LOGGER.info( + "Refreshing credentials due to a %s response. Attempt %s/%s.", + response.status, + _credential_refresh_attempt + 1, + self._max_refresh_attempts, + ) + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + async with self._refresh_lock: + await self._loop.run_in_executor( + None, self.credentials.refresh, auth_request + ) + + remaining_time = guard.remaining_timeout + + return await self.request( + method, + url, + data=data, + headers=headers, + max_allowed_time=remaining_time, + timeout=timeout, + _credential_refresh_attempt=_credential_refresh_attempt + 1, + **kwargs + ) + + await self._auth_request_session.close() + + return response diff --git a/noxfile.py b/noxfile.py index c39f27c47..0bd9544fa 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,9 +28,35 @@ "cryptography", "responses", "grpcio", + "pytest-asyncio", + "aioresponses", ] + +TEST_DEPENDENCIES2 = [ + "flask", + "freezegun", + "mock", + "oauth2client", + "pyopenssl", + "pytest", + "pytest-cov", + "pytest-localserver", + "requests", + "urllib3", + "cryptography", + "responses", + "grpcio", +] + BLACK_VERSION = "black==19.3b0" -BLACK_PATHS = ["google", "tests", "noxfile.py", "setup.py", "docs/conf.py"] +BLACK_PATHS = [ + "google", + "tests", + "tests_async", + "noxfile.py", + "setup.py", + "docs/conf.py", +] @nox.session(python="3.7") @@ -44,6 +70,7 @@ def lint(session): "--application-import-names=google,tests,system_tests", "google", "tests", + "tests_async", ) session.run( "python", "setup.py", "check", "--metadata", "--restructuredtext", "--strict" @@ -64,10 +91,24 @@ def blacken(session): session.run("black", *BLACK_PATHS) -@nox.session(python=["2.7", "3.5", "3.6", "3.7", "3.8"]) +@nox.session(python=["3.6", "3.7", "3.8"]) def unit(session): session.install(*TEST_DEPENDENCIES) session.install(".") + session.run( + "pytest", + "--cov=google.auth", + "--cov=google.oauth2", + "--cov=tests", + "tests", + "tests_async", + ) + + +@nox.session(python=["2.7", "3.5"]) +def unit_prev_versions(session): + session.install(*TEST_DEPENDENCIES2) + session.install(".") session.run( "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests" ) @@ -82,8 +123,10 @@ def cover(session): "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", + "--cov=tests_async", "--cov-report=", "tests", + "tests_async", ) session.run("coverage", "report", "--show-missing", "--fail-under=100") @@ -117,5 +160,10 @@ def pypy(session): session.install(*TEST_DEPENDENCIES) session.install(".") session.run( - "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests" + "pytest", + "--cov=google.auth", + "--cov=google.oauth2", + "--cov=tests", + "tests", + "tests_async", ) diff --git a/tests_async/test_credentials.py b/tests_async/test_credentials.py new file mode 100644 index 000000000..377f9a7e2 --- /dev/null +++ b/tests_async/test_credentials.py @@ -0,0 +1,183 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pytest + +from google.auth import _helpers +from google.auth import credentials_async as credentials + + +class CredentialsImpl(credentials.Credentials): + def refresh(self, request): + self.token = request + + def with_quota_project(self, quota_project_id): + raise NotImplementedError() + + +def test_credentials_constructor(): + credentials = CredentialsImpl() + assert not credentials.token + assert not credentials.expiry + assert not credentials.expired + assert not credentials.valid + + +def test_expired_and_valid(): + credentials = CredentialsImpl() + credentials.token = "token" + + assert credentials.valid + assert not credentials.expired + + # Set the expiration to one second more than now plus the clock skew + # accomodation. These credentials should be valid. + credentials.expiry = ( + datetime.datetime.utcnow() + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1) + ) + + assert credentials.valid + assert not credentials.expired + + # Set the credentials expiration to now. Because of the clock skew + # accomodation, these credentials should report as expired. + credentials.expiry = datetime.datetime.utcnow() + + assert not credentials.valid + assert credentials.expired + + +@pytest.mark.asyncio +async def test_before_request(): + credentials = CredentialsImpl() + request = "token" + headers = {} + + # First call should call refresh, setting the token. + await credentials.before_request(request, "http://example.com", "GET", headers) + assert credentials.valid + assert credentials.token == "token" + assert headers["authorization"] == "Bearer token" + + request = "token2" + headers = {} + + # Second call shouldn't call refresh. + credentials.before_request(request, "http://example.com", "GET", headers) + + assert credentials.valid + assert credentials.token == "token" + + +def test_anonymous_credentials_ctor(): + anon = credentials.AnonymousCredentials() + + assert anon.token is None + assert anon.expiry is None + assert not anon.expired + assert anon.valid + + +def test_anonymous_credentials_refresh(): + anon = credentials.AnonymousCredentials() + + request = object() + with pytest.raises(ValueError): + anon.refresh(request) + + +def test_anonymous_credentials_apply_default(): + anon = credentials.AnonymousCredentials() + headers = {} + anon.apply(headers) + assert headers == {} + with pytest.raises(ValueError): + anon.apply(headers, token="TOKEN") + + +def test_anonymous_credentials_before_request(): + anon = credentials.AnonymousCredentials() + request = object() + method = "GET" + url = "https://example.com/api/endpoint" + headers = {} + anon.before_request(request, method, url, headers) + assert headers == {} + + +def test_anonymous_credentials_with_quota_project(): + with pytest.raises(ValueError): + anon = credentials.AnonymousCredentials() + anon.with_quota_project("project-foo") + + +class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl): + @property + def requires_scopes(self): + return super(ReadOnlyScopedCredentialsImpl, self).requires_scopes + + +def test_readonly_scoped_credentials_constructor(): + credentials = ReadOnlyScopedCredentialsImpl() + assert credentials._scopes is None + + +def test_readonly_scoped_credentials_scopes(): + credentials = ReadOnlyScopedCredentialsImpl() + credentials._scopes = ["one", "two"] + assert credentials.scopes == ["one", "two"] + assert credentials.has_scopes(["one"]) + assert credentials.has_scopes(["two"]) + assert credentials.has_scopes(["one", "two"]) + assert not credentials.has_scopes(["three"]) + + +def test_readonly_scoped_credentials_requires_scopes(): + credentials = ReadOnlyScopedCredentialsImpl() + assert not credentials.requires_scopes + + +class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl): + def __init__(self, scopes=None): + super(RequiresScopedCredentialsImpl, self).__init__() + self._scopes = scopes + + @property + def requires_scopes(self): + return not self.scopes + + def with_scopes(self, scopes): + return RequiresScopedCredentialsImpl(scopes=scopes) + + +def test_create_scoped_if_required_scoped(): + unscoped_credentials = RequiresScopedCredentialsImpl() + scoped_credentials = credentials.with_scopes_if_required( + unscoped_credentials, ["one", "two"] + ) + + assert scoped_credentials is not unscoped_credentials + assert not scoped_credentials.requires_scopes + assert scoped_credentials.has_scopes(["one", "two"]) + + +def test_create_scoped_if_required_not_scopes(): + unscoped_credentials = CredentialsImpl() + scoped_credentials = credentials.with_scopes_if_required( + unscoped_credentials, ["one", "two"] + ) + + assert scoped_credentials is unscoped_credentials diff --git a/tests_async/transport/__init__.py b/tests_async/transport/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py new file mode 100644 index 000000000..0f204bd56 --- /dev/null +++ b/tests_async/transport/async_compliance.py @@ -0,0 +1,136 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import flask +import pytest +from pytest_localserver.http import WSGIServer +from six.moves import http_client + + +from google.auth import exceptions + +# .invalid will never resolve, see https://tools.ietf.org/html/rfc2606 +NXDOMAIN = "test.invalid" + + +class RequestResponseTests(object): + @pytest.fixture(scope="module") + def server(self): + """Provides a test HTTP server. + + The test server is automatically created before + a test and destroyed at the end. The server is serving a test + application that can be used to verify requests. + """ + app = flask.Flask(__name__) + app.debug = True + + # pylint: disable=unused-variable + # (pylint thinks the flask routes are unusued.) + @app.route("/basic") + def index(): + header_value = flask.request.headers.get("x-test-header", "value") + headers = {"X-Test-Header": header_value} + return "Basic Content", http_client.OK, headers + + @app.route("/server_error") + def server_error(): + return "Error", http_client.INTERNAL_SERVER_ERROR + + @app.route("/wait") + def wait(): + time.sleep(3) + return "Waited" + + # pylint: enable=unused-variable + + server = WSGIServer(application=app.wsgi_app) + server.start() + yield server + server.stop() + + @pytest.mark.asyncio + async def test_request_basic(self, server): + request = self.make_request() + response = await request(url=server.url + "/basic", method="GET") + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + # Use 13 as this is the length of the data written into the stream. + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_basic_with_http(self, server): + request = self.make_with_parameter_request() + response = await request(url=server.url + "/basic", method="GET") + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + # Use 13 as this is the length of the data written into the stream. + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_with_timeout_success(self, server): + request = self.make_request() + response = await request(url=server.url + "/basic", method="GET", timeout=2) + + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_with_timeout_failure(self, server): + request = self.make_request() + + with pytest.raises(exceptions.TransportError): + await request(url=server.url + "/wait", method="GET", timeout=1) + + @pytest.mark.asyncio + async def test_request_headers(self, server): + request = self.make_request() + response = await request( + url=server.url + "/basic", + method="GET", + headers={"x-test-header": "hello world"}, + ) + + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "hello world" + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_error(self, server): + request = self.make_request() + + response = await request(url=server.url + "/server_error", method="GET") + assert response.status == http_client.INTERNAL_SERVER_ERROR + data = await response.data.read(5) + assert data == b"Error" + + @pytest.mark.asyncio + async def test_connection_error(self): + request = self.make_request() + + with pytest.raises(exceptions.TransportError): + await request(url="http://{}".format(NXDOMAIN), method="GET") diff --git a/tests_async/transport/test_aiohttp_req.py b/tests_async/transport/test_aiohttp_req.py new file mode 100644 index 000000000..1643eaa06 --- /dev/null +++ b/tests_async/transport/test_aiohttp_req.py @@ -0,0 +1,163 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import aiohttp +from aioresponses import aioresponses +import mock +import pytest +from tests_async.transport import async_compliance + +import google.auth.credentials_async +from google.auth.transport import aiohttp_req +import google.auth.transport._mtls_helper + + +class TestRequestResponse(async_compliance.RequestResponseTests): + def make_request(self): + return aiohttp_req.Request() + + def make_with_parameter_request(self): + http = mock.create_autospec(aiohttp.ClientSession, instance=True) + return aiohttp_req.Request(http) + + def test_timeout(self): + http = mock.create_autospec(aiohttp.ClientSession, instance=True) + request = google.auth.transport.aiohttp_req.Request(http) + request(url="http://example.com", method="GET", timeout=5) + + +class CredentialsStub(google.auth.credentials_async.Credentials): + def __init__(self, token="token"): + super(CredentialsStub, self).__init__() + self.token = token + + def apply(self, headers, token=None): + headers["authorization"] = self.token + + def refresh(self, request): + self.token += "1" + + +class TestAuthorizedSession(object): + TEST_URL = "http://example.com/" + method = "GET" + + def test_constructor(self): + authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + mock.sentinel.credentials + ) + assert authed_session.credentials == mock.sentinel.credentials + + def test_constructor_with_auth_request(self): + http = mock.create_autospec(aiohttp.ClientSession) + auth_request = google.auth.transport.aiohttp_req.Request(http) + + authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + mock.sentinel.credentials, auth_request=auth_request + ) + + assert authed_session._auth_request == auth_request + + @pytest.mark.asyncio + async def test_request(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + + mocked.get(self.TEST_URL, status=200, body="test") + session = aiohttp_req.AuthorizedSession(credentials) + resp = await session.request("GET", "http://example.com/") + + assert resp.status == 200 + assert "test" == await resp.text() + + await session.close() + + @pytest.mark.asyncio + async def test_ctx(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.get("http://test.example.com", payload=dict(foo="bar")) + session = aiohttp_req.AuthorizedSession(credentials) + resp = await session.request("GET", "http://test.example.com") + data = await resp.json() + + assert dict(foo="bar") == data + + await session.close() + + @pytest.mark.asyncio + async def test_http_headers(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.post( + "http://example.com", + payload=dict(), + headers=dict(connection="keep-alive"), + ) + + session = aiohttp_req.AuthorizedSession(credentials) + resp = await session.request("POST", "http://example.com") + + assert resp.headers["Connection"] == "keep-alive" + + await session.close() + + @pytest.mark.asyncio + async def test_regexp_example(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.get("http://example.com", status=500) + mocked.get("http://example.com", status=200) + + session1 = aiohttp_req.AuthorizedSession(credentials) + + resp1 = await session1.request("GET", "http://example.com") + session2 = aiohttp_req.AuthorizedSession(credentials) + resp2 = await session2.request("GET", "http://example.com") + + assert resp1.status == 500 + assert resp2.status == 200 + + await session1.close() + await session2.close() + + @pytest.mark.asyncio + async def test_request_no_refresh(self): + credentials = mock.Mock(wraps=CredentialsStub()) + with aioresponses() as mocked: + mocked.get("http://example.com", status=200) + authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + credentials + ) + response = await authed_session.request("GET", "http://example.com") + assert response.status == 200 + assert credentials.before_request.called + assert not credentials.refresh.called + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_refresh(self): + credentials = mock.Mock(wraps=CredentialsStub()) + with aioresponses() as mocked: + mocked.get("http://example.com", status=401) + mocked.get("http://example.com", status=200) + authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + credentials + ) + response = await authed_session.request("GET", "http://example.com") + assert credentials.refresh.called + assert response.status == 200 + + await authed_session.close() From 0a63b3c0e2a0f3aa71e2af8ac8f0cca4343e6b21 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Mon, 27 Jul 2020 21:31:24 -0500 Subject: [PATCH 02/20] feat: all asynchronous credentials types implemented and with tests --- google/auth/__init__.py | 3 +- google/auth/_default_async.py | 333 ++++++++++ google/auth/_oauth2client_async.py | 171 +++++ google/auth/credentials_async.py | 6 +- google/auth/jwt_async.py | 362 +++++++++++ google/oauth2/_client_async.py | 273 ++++++++ google/oauth2/credentials_async.py | 107 ++++ google/oauth2/service_account_async.py | 134 ++++ tests_async/__init__.py | 0 tests_async/conftest.py | 51 ++ tests_async/oauth2/test__client_async.py | 313 ++++++++++ tests_async/oauth2/test_credentials_async.py | 482 +++++++++++++++ .../oauth2/test_service_account_async.py | 367 +++++++++++ tests_async/test__default.py | 484 +++++++++++++++ tests_async/test__oauth2client.py | 170 +++++ tests_async/test_jwt.py | 583 ++++++++++++++++++ 16 files changed, 3837 insertions(+), 2 deletions(-) create mode 100644 google/auth/_default_async.py create mode 100644 google/auth/_oauth2client_async.py create mode 100644 google/auth/jwt_async.py create mode 100644 google/oauth2/_client_async.py create mode 100644 google/oauth2/credentials_async.py create mode 100644 google/oauth2/service_account_async.py create mode 100644 tests_async/__init__.py create mode 100644 tests_async/conftest.py create mode 100644 tests_async/oauth2/test__client_async.py create mode 100644 tests_async/oauth2/test_credentials_async.py create mode 100644 tests_async/oauth2/test_service_account_async.py create mode 100644 tests_async/test__default.py create mode 100644 tests_async/test__oauth2client.py create mode 100644 tests_async/test_jwt.py diff --git a/google/auth/__init__.py b/google/auth/__init__.py index 5ca20a362..b03add240 100644 --- a/google/auth/__init__.py +++ b/google/auth/__init__.py @@ -17,9 +17,10 @@ import logging from google.auth._default import default, load_credentials_from_file +from google.auth._default_async import default_async -__all__ = ["default", "load_credentials_from_file"] +__all__ = ["default", "load_credentials_from_file", "default_async"] # Set default logging handler to avoid "No handler found" warnings. diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py new file mode 100644 index 000000000..b901aa0a7 --- /dev/null +++ b/google/auth/_default_async.py @@ -0,0 +1,333 @@ +# Copyright 2015 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Application default credentials. + +Implements application default credentials and project ID detection. +""" + +import io +import json +import logging +import os +import warnings + +import six + +from google.auth import environment_vars +from google.auth import exceptions +import google.auth.transport._http_client + +_LOGGER = logging.getLogger(__name__) + +# Valid types accepted for file-based credentials. +_AUTHORIZED_USER_TYPE = "authorized_user" +_SERVICE_ACCOUNT_TYPE = "service_account" +_VALID_TYPES = (_AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE) + +# Help message when no credentials can be found. +_HELP_MESSAGE = """\ +Could not automatically determine credentials. Please set {env} or \ +explicitly create credentials and re-run the application. For more \ +information, please see \ +https://cloud.google.com/docs/authentication/getting-started +""".format( + env=environment_vars.CREDENTIALS +).strip() + +# Warning when using Cloud SDK user credentials +_CLOUD_SDK_CREDENTIALS_WARNING = """\ +Your application has authenticated using end user credentials from Google \ +Cloud SDK without a quota project. You might receive a "quota exceeded" \ +or "API not enabled" error. We recommend you rerun \ +`gcloud auth application-default login` and make sure a quota project is \ +added. Or you can use service accounts instead. For more information \ +about service accounts, see https://cloud.google.com/docs/authentication/""" + + +def _warn_about_problematic_credentials(credentials): + """Determines if the credentials are problematic. + + Credentials from the Cloud SDK that are associated with Cloud SDK's project + are problematic because they may not have APIs enabled and have limited + quota. If this is the case, warn about it. + """ + from google.auth import _cloud_sdk + + if credentials.client_id == _cloud_sdk.CLOUD_SDK_CLIENT_ID: + warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING) + + +def load_credentials_from_file(filename, scopes=None, quota_project_id=None): + """Loads Google credentials from a file. + + The credentials file must be a service account key or stored authorized + user credentials. + + Args: + filename (str): The full path to the credentials file. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary + quota_project_id (Optional[str]): The project ID used for + quota and billing. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. Authorized user credentials do not + have the project ID information. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the file is in the + wrong format or is missing. + """ + if not os.path.exists(filename): + raise exceptions.DefaultCredentialsError( + "File {} was not found.".format(filename) + ) + + with io.open(filename, "r") as file_obj: + try: + info = json.load(file_obj) + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "File {} is not a valid json file.".format(filename), caught_exc + ) + six.raise_from(new_exc, caught_exc) + + # The type key should indicate that the file is either a service account + # credentials file or an authorized user credentials file. + credential_type = info.get("type") + + if credential_type == _AUTHORIZED_USER_TYPE: + from google.oauth2 import credentials_async as credentials + + try: + credentials = credentials.Credentials.from_authorized_user_info( + info, scopes=scopes + ).with_quota_project(quota_project_id) + except ValueError as caught_exc: + msg = "Failed to load authorized user credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + six.raise_from(new_exc, caught_exc) + if not credentials.quota_project_id: + _warn_about_problematic_credentials(credentials) + return credentials, None + + elif credential_type == _SERVICE_ACCOUNT_TYPE: + from google.oauth2 import service_account_async as service_account + + try: + credentials = service_account.Credentials.from_service_account_info( + info, scopes=scopes + ).with_quota_project(quota_project_id) + except ValueError as caught_exc: + msg = "Failed to load service account credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + six.raise_from(new_exc, caught_exc) + return credentials, info.get("project_id") + + else: + raise exceptions.DefaultCredentialsError( + "The file {file} does not have a valid type. " + "Type is {type}, expected one of {valid_types}.".format( + file=filename, type=credential_type, valid_types=_VALID_TYPES + ) + ) + + +def _get_gcloud_sdk_credentials(): + """Gets the credentials and project ID from the Cloud SDK.""" + from google.auth import _cloud_sdk + + # Check if application default credentials exist. + credentials_filename = _cloud_sdk.get_application_default_credentials_path() + + if not os.path.isfile(credentials_filename): + return None, None + + credentials, project_id = load_credentials_from_file(credentials_filename) + + if not project_id: + project_id = _cloud_sdk.get_project_id() + + return credentials, project_id + + +def _get_explicit_environ_credentials(): + """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + variable.""" + explicit_file = os.environ.get(environment_vars.CREDENTIALS) + + if explicit_file is not None: + credentials, project_id = load_credentials_from_file( + os.environ[environment_vars.CREDENTIALS] + ) + + return credentials, project_id + + else: + return None, None + + +def _get_gae_credentials(): + """Gets Google App Engine App Identity credentials and project ID.""" + # While this library is normally bundled with app_engine, there are + # some cases where it's not available, so we tolerate ImportError. + try: + import google.auth.app_engine as app_engine + except ImportError: + return None, None + + try: + credentials = app_engine.Credentials() + project_id = app_engine.get_project_id() + return credentials, project_id + except EnvironmentError: + return None, None + + +def _get_gce_credentials(request=None): + """Gets credentials and project ID from the GCE Metadata Service.""" + # Ping requires a transport, but we want application default credentials + # to require no arguments. So, we'll use the _http_client transport which + # uses http.client. This is only acceptable because the metadata server + # doesn't do SSL and never requires proxies. + + # While this library is normally bundled with compute_engine, there are + # some cases where it's not available, so we tolerate ImportError. + try: + from google.auth import compute_engine + from google.auth.compute_engine import _metadata + except ImportError: + return None, None + + if request is None: + request = google.auth.transport._http_client.Request() + + if _metadata.ping(request=request): + # Get the project ID. + try: + project_id = _metadata.get_project_id(request=request) + except exceptions.TransportError: + project_id = None + + return compute_engine.Credentials(), project_id + else: + return None, None + + +def default_async(scopes=None, request=None, quota_project_id=None): + """Gets the default credentials for the current environment. + + `Application Default Credentials`_ provides an easy way to obtain + credentials to call Google APIs for server-to-server or local applications. + This function acquires credentials from the environment in the following + order: + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON private key file, then it is + loaded and returned. The project ID returned is the project ID defined + in the service account file if available (some older files do not + contain project ID information). + 2. If the `Google Cloud SDK`_ is installed and has application default + credentials set they are loaded and returned. + + To enable application default credentials with the Cloud SDK run:: + + gcloud auth application-default login + + If the Cloud SDK has an active project, the project ID is returned. The + active project can be set using:: + + gcloud config set project + + 3. If the application is running in the `App Engine standard environment`_ + then the credentials and project ID from the `App Identity Service`_ + are used. + 4. If the application is running in `Compute Engine`_ or the + `App Engine flexible environment`_ then the credentials and project ID + are obtained from the `Metadata Service`_. + 5. If no credentials are found, + :class:`~google.auth.exceptions.DefaultCredentialsError` will be raised. + + .. _Application Default Credentials: https://developers.google.com\ + /identity/protocols/application-default-credentials + .. _Google Cloud SDK: https://cloud.google.com/sdk + .. _App Engine standard environment: https://cloud.google.com/appengine + .. _App Identity Service: https://cloud.google.com/appengine/docs/python\ + /appidentity/ + .. _Compute Engine: https://cloud.google.com/compute + .. _App Engine flexible environment: https://cloud.google.com\ + /appengine/flexible + .. _Metadata Service: https://cloud.google.com/compute/docs\ + /storing-retrieving-metadata + + Example:: + + import google.auth + + credentials, project_id = google.auth.default() + + Args: + scopes (Sequence[str]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + request (google.auth.transport.Request): An object used to make + HTTP requests. This is used to detect whether the application + is running on Compute Engine. If not specified, then it will + use the standard library http client to make requests. + quota_project_id (Optional[str]): The project ID used for + quota and billing. + Returns: + Tuple[~google.auth.credentials.Credentials, Optional[str]]: + the current environment's credentials and project ID. Project ID + may be None, which indicates that the Project ID could not be + ascertained from the environment. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If no credentials were found, or if the credentials found were + invalid. + """ + from google.auth.credentials_async import with_scopes_if_required + + explicit_project_id = os.environ.get( + environment_vars.PROJECT, os.environ.get(environment_vars.LEGACY_PROJECT) + ) + + checkers = ( + _get_explicit_environ_credentials, + _get_gcloud_sdk_credentials, + _get_gae_credentials, + lambda: _get_gce_credentials(request), + ) + + for checker in checkers: + credentials, project_id = checker() + if credentials is not None: + credentials = with_scopes_if_required( + credentials, scopes + ).with_quota_project(quota_project_id) + effective_project_id = explicit_project_id or project_id + if not effective_project_id: + _LOGGER.warning( + "No project ID could be determined. Consider running " + "`gcloud config set project` or setting the %s " + "environment variable", + environment_vars.PROJECT, + ) + return credentials, effective_project_id + + raise exceptions.DefaultCredentialsError(_HELP_MESSAGE) diff --git a/google/auth/_oauth2client_async.py b/google/auth/_oauth2client_async.py new file mode 100644 index 000000000..2913134a4 --- /dev/null +++ b/google/auth/_oauth2client_async.py @@ -0,0 +1,171 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for transitioning from oauth2client to google-auth. + +.. warning:: + This module is private as it is intended to assist first-party downstream + clients with the transition from oauth2client to google-auth. +""" + +from __future__ import absolute_import + +import six + +from google.auth import _helpers +import google.auth.app_engine +import google.auth.compute_engine +import google.oauth2.credentials +import google.oauth2.service_account_async + +try: + import oauth2client.client + import oauth2client.contrib.gce + import oauth2client.service_account +except ImportError as caught_exc: + six.raise_from(ImportError("oauth2client is not installed."), caught_exc) + +try: + import oauth2client.contrib.appengine # pytype: disable=import-error + + _HAS_APPENGINE = True +except ImportError: + _HAS_APPENGINE = False + + +_CONVERT_ERROR_TMPL = "Unable to convert {} to a google-auth credentials class." + + +def _convert_oauth2_credentials(credentials): + """Converts to :class:`google.oauth2.credentials_async.Credentials`. + + Args: + credentials (Union[oauth2client.client.OAuth2Credentials, + oauth2client.client.GoogleCredentials]): The credentials to + convert. + + Returns: + google.oauth2.credentials_async.Credentials: The converted credentials. + """ + new_credentials = google.oauth2.credentials.Credentials( + token=credentials.access_token, + refresh_token=credentials.refresh_token, + token_uri=credentials.token_uri, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + scopes=credentials.scopes, + ) + + new_credentials._expires = credentials.token_expiry + + return new_credentials + + +def _convert_service_account_credentials(credentials): + """Converts to :class:`google.oauth2.service_account_async.Credentials`. + + Args: + credentials (Union[ + oauth2client.service_account_async.ServiceAccountCredentials, + oauth2client.service_account_async._JWTAccessCredentials]): The + credentials to convert. + + Returns: + google.oauth2.service_account_async.Credentials: The converted credentials. + """ + info = credentials.serialization_data.copy() + info["token_uri"] = credentials.token_uri + return google.oauth2.service_account_async.Credentials.from_service_account_info( + info + ) + + +def _convert_gce_app_assertion_credentials(credentials): + """Converts to :class:`google.auth.compute_engine.Credentials`. + + Args: + credentials (oauth2client.contrib.gce.AppAssertionCredentials): The + credentials to convert. + + Returns: + google.oauth2.service_account_async.Credentials: The converted credentials. + """ + return google.auth.compute_engine.Credentials( + service_account_email=credentials.service_account_email + ) + + +def _convert_appengine_app_assertion_credentials(credentials): + """Converts to :class:`google.auth.app_engine.Credentials`. + + Args: + credentials (oauth2client.contrib.app_engine.AppAssertionCredentials): + The credentials to convert. + + Returns: + google.oauth2.service_account_async.Credentials: The converted credentials. + """ + # pylint: disable=invalid-name + return google.auth.app_engine.Credentials( + scopes=_helpers.string_to_scopes(credentials.scope), + service_account_id=credentials.service_account_id, + ) + + +_CLASS_CONVERSION_MAP = { + oauth2client.client.OAuth2Credentials: _convert_oauth2_credentials, + oauth2client.client.GoogleCredentials: _convert_oauth2_credentials, + oauth2client.service_account.ServiceAccountCredentials: _convert_service_account_credentials, + oauth2client.service_account._JWTAccessCredentials: _convert_service_account_credentials, + oauth2client.contrib.gce.AppAssertionCredentials: _convert_gce_app_assertion_credentials, +} + +if _HAS_APPENGINE: + _CLASS_CONVERSION_MAP[ + oauth2client.contrib.appengine.AppAssertionCredentials + ] = _convert_appengine_app_assertion_credentials + + +def convert(credentials): + """Convert oauth2client credentials to google-auth credentials. + + This class converts: + + - :class:`oauth2client.client_async.OAuth2Credentials` to + :class:`google.oauth2.credentials_async.Credentials`. + - :class:`oauth2client.client_async.GoogleCredentials` to + :class:`google.oauth2.credentials_async.Credentials`. + - :class:`oauth2client.service_account_async.ServiceAccountCredentials` to + :class:`google.oauth2.service_account_async.Credentials`. + - :class:`oauth2client.service_account_async._JWTAccessCredentials` to + :class:`google.oauth2.service_account_async.Credentials`. + - :class:`oauth2client.contrib.gce.AppAssertionCredentials` to + :class:`google.auth.compute_engine.Credentials`. + - :class:`oauth2client.contrib.appengine.AppAssertionCredentials` to + :class:`google.auth.app_engine.Credentials`. + + Returns: + google.auth.credentials_async.Credentials: The converted credentials. + + Raises: + ValueError: If the credentials could not be converted. + """ + + credentials_class = type(credentials) + + try: + return _CLASS_CONVERSION_MAP[credentials_class](credentials) + except KeyError as caught_exc: + new_exc = ValueError(_CONVERT_ERROR_TMPL.format(credentials_class)) + six.raise_from(new_exc, caught_exc) diff --git a/google/auth/credentials_async.py b/google/auth/credentials_async.py index a131cc44b..5916e45d9 100644 --- a/google/auth/credentials_async.py +++ b/google/auth/credentials_async.py @@ -16,6 +16,7 @@ """Interfaces for credentials.""" import abc +import inspect import six @@ -62,7 +63,10 @@ async def before_request(self, request, method, url, headers): # the http request.) if not self.valid: - self.refresh(request) + if inspect.iscoroutinefunction(self.refresh): + await self.refresh(request) + else: + self.refresh(request) self.apply(headers) diff --git a/google/auth/jwt_async.py b/google/auth/jwt_async.py new file mode 100644 index 000000000..e9782bffb --- /dev/null +++ b/google/auth/jwt_async.py @@ -0,0 +1,362 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JSON Web Tokens + +Provides support for creating (encoding) and verifying (decoding) JWTs, +especially JWTs generated and consumed by Google infrastructure. + +See `rfc7519`_ for more details on JWTs. + +To encode a JWT use :func:`encode`:: + + from google.auth import crypt + from google.auth import jwt + + signer = crypt.Signer(private_key) + payload = {'some': 'payload'} + encoded = jwt.encode(signer, payload) + +To decode a JWT and verify claims use :func:`decode`:: + + claims = jwt.decode(encoded, certs=public_certs) + +You can also skip verification:: + + claims = jwt.decode(encoded, verify=False) + +.. _rfc7519: https://tools.ietf.org/html/rfc7519 + +""" + +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping + +import json + +import six + +import google.auth +from google.auth import _helpers +from google.auth import crypt +from google.auth import jwt + +try: + from google.auth.crypt import es256 +except ImportError: # pragma: NO COVER + es256 = None + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds +_DEFAULT_MAX_CACHE_SIZE = 10 +_ALGORITHM_TO_VERIFIER_CLASS = {"RS256": crypt.RSAVerifier} +_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256"]) + +if es256 is not None: # pragma: NO COVER + _ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es256.ES256Verifier + + +def encode(signer, payload, header=None, key_id=None): + """Make a signed JWT. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign the JWT. + payload (Mapping[str, str]): The JWT payload. + header (Mapping[str, str]): Additional JWT header payload. + key_id (str): The key id to add to the JWT header. If the + signer has a key id it will be used as the default. If this is + specified it will override the signer's key id. + + Returns: + bytes: The encoded JWT. + """ + if header is None: + header = {} + + if key_id is None: + key_id = signer.key_id + + header.update({"typ": "JWT"}) + + if es256 is not None and isinstance(signer, es256.ES256Signer): + header.update({"alg": "ES256"}) + else: + header.update({"alg": "RS256"}) + + if key_id is not None: + header["kid"] = key_id + + segments = [ + _helpers.unpadded_urlsafe_b64encode(json.dumps(header).encode("utf-8")), + _helpers.unpadded_urlsafe_b64encode(json.dumps(payload).encode("utf-8")), + ] + + signing_input = b".".join(segments) + signature = signer.sign(signing_input) + segments.append(_helpers.unpadded_urlsafe_b64encode(signature)) + + return b".".join(segments) + + +def _decode_jwt_segment(encoded_section): + """Decodes a single JWT segment.""" + section_bytes = _helpers.padded_urlsafe_b64decode(encoded_section) + try: + return json.loads(section_bytes.decode("utf-8")) + except ValueError as caught_exc: + new_exc = ValueError("Can't parse segment: {0}".format(section_bytes)) + six.raise_from(new_exc, caught_exc) + + +def _unverified_decode(token): + """Decodes a token and does no verification. + + Args: + token (Union[str, bytes]): The encoded JWT. + + Returns: + Tuple[str, str, str, str]: header, payload, signed_section, and + signature. + + Raises: + ValueError: if there are an incorrect amount of segments in the token. + """ + token = _helpers.to_bytes(token) + + if token.count(b".") != 2: + raise ValueError("Wrong number of segments in token: {0}".format(token)) + + encoded_header, encoded_payload, signature = token.split(b".") + signed_section = encoded_header + b"." + encoded_payload + signature = _helpers.padded_urlsafe_b64decode(signature) + + # Parse segments + header = _decode_jwt_segment(encoded_header) + payload = _decode_jwt_segment(encoded_payload) + + return header, payload, signed_section, signature + + +def decode_header(token): + """Return the decoded header of a token. + + No verification is done. This is useful to extract the key id from + the header in order to acquire the appropriate certificate to verify + the token. + + Args: + token (Union[str, bytes]): the encoded JWT. + + Returns: + Mapping: The decoded JWT header. + """ + header, _, _, _ = _unverified_decode(token) + return header + + +def _verify_iat_and_exp(payload): + """Verifies the ``iat`` (Issued At) and ``exp`` (Expires) claims in a token + payload. + + Args: + payload (Mapping[str, str]): The JWT payload. + + Raises: + ValueError: if any checks failed. + """ + now = _helpers.datetime_to_secs(_helpers.utcnow()) + + # Make sure the iat and exp claims are present. + for key in ("iat", "exp"): + if key not in payload: + raise ValueError("Token does not contain required claim {}".format(key)) + + # Make sure the token wasn't issued in the future. + iat = payload["iat"] + # Err on the side of accepting a token that is slightly early to account + # for clock skew. + earliest = iat - _helpers.CLOCK_SKEW_SECS + if now < earliest: + raise ValueError("Token used too early, {} < {}".format(now, iat)) + + # Make sure the token wasn't issued in the past. + exp = payload["exp"] + # Err on the side of accepting a token that is slightly out of date + # to account for clow skew. + latest = exp + _helpers.CLOCK_SKEW_SECS + if latest < now: + raise ValueError("Token expired, {} < {}".format(latest, now)) + + +def decode(token, certs=None, verify=True, audience=None): + """Decode and verify a JWT. + + Args: + token (str): The encoded JWT. + certs (Union[str, bytes, Mapping[str, Union[str, bytes]]]): The + certificate used to validate the JWT signature. If bytes or string, + it must the the public key certificate in PEM format. If a mapping, + it must be a mapping of key IDs to public key certificates in PEM + format. The mapping must contain the same key ID that's specified + in the token's header. + verify (bool): Whether to perform signature and claim validation. + Verification is done by default. + audience (str): The audience claim, 'aud', that this JWT should + contain. If None then the JWT's 'aud' parameter is not verified. + + Returns: + Mapping[str, str]: The deserialized JSON payload in the JWT. + + Raises: + ValueError: if any verification checks failed. + """ + header, payload, signed_section, signature = _unverified_decode(token) + + if not verify: + return payload + + # Pluck the key id and algorithm from the header and make sure we have + # a verifier that can support it. + key_alg = header.get("alg") + key_id = header.get("kid") + + try: + verifier_cls = _ALGORITHM_TO_VERIFIER_CLASS[key_alg] + except KeyError as exc: + if key_alg in _CRYPTOGRAPHY_BASED_ALGORITHMS: + six.raise_from( + ValueError( + "The key algorithm {} requires the cryptography package " + "to be installed.".format(key_alg) + ), + exc, + ) + else: + six.raise_from( + ValueError("Unsupported signature algorithm {}".format(key_alg)), exc + ) + + # If certs is specified as a dictionary of key IDs to certificates, then + # use the certificate identified by the key ID in the token header. + if isinstance(certs, Mapping): + if key_id: + if key_id not in certs: + raise ValueError("Certificate for key id {} not found.".format(key_id)) + certs_to_check = [certs[key_id]] + # If there's no key id in the header, check against all of the certs. + else: + certs_to_check = certs.values() + else: + certs_to_check = certs + + # Verify that the signature matches the message. + if not crypt.verify_signature( + signed_section, signature, certs_to_check, verifier_cls + ): + raise ValueError("Could not verify token signature.") + + # Verify the issued at and created times in the payload. + _verify_iat_and_exp(payload) + + # Check audience. + if audience is not None: + claim_audience = payload.get("aud") + if audience != claim_audience: + raise ValueError( + "Token has wrong audience {}, expected {}".format( + claim_audience, audience + ) + ) + + return payload + + +class Credentials( + jwt.Credentials, + google.auth.credentials_async.Signing, + google.auth.credentials_async.Credentials, +): + """Credentials that use a JWT as the bearer token. + + These credentials require an "audience" claim. This claim identifies the + intended recipient of the bearer token. + + The constructor arguments determine the claims for the JWT that is + sent with requests. Usually, you'll construct these credentials with + one of the helper constructors as shown in the next section. + + To create JWT credentials using a Google service account private key + JSON file:: + + audience = 'https://pubsub.googleapis.com/google.pubsub.v1.Publisher' + credentials = jwt_async.Credentials.from_service_account_file( + 'service-account.json', + audience=audience) + + If you already have the service account file loaded and parsed:: + + service_account_info = json.load(open('service_account.json')) + credentials = jwt_async.Credentials.from_service_account_info( + service_account_info, + audience=audience) + + Both helper methods pass on arguments to the constructor, so you can + specify the JWT claims:: + + credentials = jwt_async.Credentials.from_service_account_file( + 'service-account.json', + audience=audience, + additional_claims={'meta': 'data'}) + + You can also construct the credentials directly if you have a + :class:`~google.auth.crypt.Signer` instance:: + + credentials = jwt_async.Credentials( + signer, + issuer='your-issuer', + subject='your-subject', + audience=audience) + + The claims are considered immutable. If you want to modify the claims, + you can easily create another instance using :meth:`with_claims`:: + + new_audience = ( + 'https://pubsub.googleapis.com/google.pubsub.v1.Subscriber') + new_credentials = credentials.with_claims(audience=new_audience) + """ + + +class OnDemandCredentials( + jwt.OnDemandCredentials, + google.auth.credentials_async.Signing, + google.auth.credentials_async.Credentials, +): + """On-demand JWT credentials. + + Like :class:`Credentials`, this class uses a JWT as the bearer token for + authentication. However, this class does not require the audience at + construction time. Instead, it will generate a new token on-demand for + each request using the request URI as the audience. It caches tokens + so that multiple requests to the same URI do not incur the overhead + of generating a new token every time. + + This behavior is especially useful for `gRPC`_ clients. A gRPC service may + have multiple audience and gRPC clients may not know all of the audiences + required for accessing a particular service. With these credentials, + no knowledge of the audiences is required ahead of time. + + .. _grpc: http://www.grpc.io/ + """ diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py new file mode 100644 index 000000000..e498b8113 --- /dev/null +++ b/google/oauth2/_client_async.py @@ -0,0 +1,273 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 client. + +This is a client for interacting with an OAuth 2.0 authorization server's +token endpoint. + +For more information about the token endpoint, see +`Section 3.1 of rfc6749`_ + +.. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 +""" + +import datetime +import json + +import six +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import jwt + +_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" +_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" +_REFRESH_GRANT_TYPE = "refresh_token" + + +def _handle_error_response(response_body): + """"Translates an error response into an exception. + + Args: + response_body (str): The decoded response data. + + Raises: + google.auth.exceptions.RefreshError + """ + try: + error_data = json.loads(response_body) + error_details = "{}: {}".format( + error_data["error"], error_data.get("error_description") + ) + # If no details could be extracted, use the response data. + except (KeyError, ValueError): + error_details = response_body + + raise exceptions.RefreshError(error_details, response_body) + + +def _parse_expiry(response_data): + """Parses the expiry field from a response into a datetime. + + Args: + response_data (Mapping): The JSON-parsed response data. + + Returns: + Optional[datetime]: The expiration or ``None`` if no expiration was + specified. + """ + expires_in = response_data.get("expires_in", None) + + if expires_in is not None: + return _helpers.utcnow() + datetime.timedelta(seconds=expires_in) + else: + return None + + +async def _token_endpoint_request(request, token_uri, body): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + + Returns: + Mapping[str, str]: The JSON-decoded response data. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = urllib.parse.urlencode(body).encode("utf-8") + headers = {"content-type": _URLENCODED_CONTENT_TYPE} + + retry = 0 + # retry to fetch token for maximum of two times if any internal failure + # occurs. + while True: + + response = await request( + method="POST", url=token_uri, headers=headers, body=body + ) + + """ + except exceptions.TransportError as caught_exc: + new_exc = exceptions.RefreshError(caught_exc) + six.raise_from(new_exc, caught_exc) + """ + + response_body1 = await response.data.read() + + response_body = ( + response_body1.decode("utf-8") + if hasattr(response_body1, "decode") + else response_body1 + ) + # CHANGE TO READ TO END OF STREAM + + response_data = json.loads(response_body) + + if response.status == http_client.OK: + break + else: + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + if ( + any(e == "internal_failure" for e in (error_code, error_desc)) + and retry < 1 + ): + retry += 1 + continue + _handle_error_response(response_body) + + return response_data + + +async def jwt_grant(request, token_uri, assertion): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants. + + For more details, see `rfc7523 section 4`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + assertion (str): The OAuth 2.0 assertion. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: The access token, + expiration, and additional data returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc7523 section 4: https://tools.ietf.org/html/rfc7523#section-4 + """ + body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} + + response_data = await _token_endpoint_request(request, token_uri, body) + + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No access token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + expiry = _parse_expiry(response_data) + + return access_token, expiry, response_data + + +async def id_token_jwt_grant(request, token_uri, assertion): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but + requests an OpenID Connect ID Token instead of an access token. + + This is a variant on the standard JWT Profile that is currently unique + to Google. This was added for the benefit of authenticating to services + that require ID Tokens instead of access tokens or JWT bearer tokens. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorization server's token endpoint + URI. + assertion (str): JWT token signed by a service account. The token's + payload must include a ``target_audience`` claim. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: + The (encoded) Open ID Connect ID Token, expiration, and additional + data returned by the endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} + + response_data = await _token_endpoint_request(request, token_uri, body) + + try: + id_token = response_data["id_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No ID token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + payload = jwt.decode(id_token, verify=False) + expiry = datetime.datetime.utcfromtimestamp(payload["exp"]) + + return id_token, expiry, response_data + + +async def refresh_grant( + request, token_uri, refresh_token, client_id, client_secret, scopes=None +): + """Implements the OAuth 2.0 refresh token grant. + + For more details, see `rfc678 section 6`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The + access token, new refresh token, expiration, and additional data + returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc6748 section 6: https://tools.ietf.org/html/rfc6749#section-6 + """ + body = { + "grant_type": _REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + + response_data = await _token_endpoint_request(request, token_uri, body) + + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No access token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + refresh_token = response_data.get("refresh_token", refresh_token) + expiry = _parse_expiry(response_data) + + return access_token, refresh_token, expiry, response_data diff --git a/google/oauth2/credentials_async.py b/google/oauth2/credentials_async.py new file mode 100644 index 000000000..b45feddc6 --- /dev/null +++ b/google/oauth2/credentials_async.py @@ -0,0 +1,107 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Credentials. + +This module provides credentials based on OAuth 2.0 access and refresh tokens. +These credentials usually access resources on behalf of a user (resource +owner). + +Specifically, this is intended to use access tokens acquired using the +`Authorization Code grant`_ and can refresh those tokens using a +optional `refresh token`_. + +Obtaining the initial access and refresh token is outside of the scope of this +module. Consult `rfc6749 section 4.1`_ for complete details on the +Authorization Code grant flow. + +.. _Authorization Code grant: https://tools.ietf.org/html/rfc6749#section-1.3.1 +.. _refresh token: https://tools.ietf.org/html/rfc6749#section-6 +.. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1 +""" + +from google.auth import _helpers +from google.auth import credentials_async as credentials +from google.auth import exceptions +from google.oauth2 import _client_async as _client +from google.oauth2 import credentials as oauth2_credentials + + +# The Google OAuth 2.0 token endpoint. Used for authorized user credentials. +_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" + + +class Credentials(oauth2_credentials.Credentials): + """Credentials using OAuth 2.0 access and refresh tokens. + + The credentials are considered immutable. If you want to modify the + quota project, use :meth:`with_quota_project` or :: + + credentials = credentials.with_quota_project('myproject-123) + """ + + @_helpers.copy_docstring(credentials.Credentials) + async def refresh(self, request): + if ( + self._refresh_token is None + or self._token_uri is None + or self._client_id is None + or self._client_secret is None + ): + raise exceptions.RefreshError( + "The credentials do not contain the necessary fields need to " + "refresh the access token. You must specify refresh_token, " + "token_uri, client_id, and client_secret." + ) + + access_token, refresh_token, expiry, grant_response = await _client.refresh_grant( + request, + self._token_uri, + self._refresh_token, + self._client_id, + self._client_secret, + self._scopes, + ) + + self.token = access_token + self.expiry = expiry + self._refresh_token = refresh_token + self._id_token = grant_response.get("id_token") + + if self._scopes and "scopes" in grant_response: + requested_scopes = frozenset(self._scopes) + granted_scopes = frozenset(grant_response["scopes"].split()) + scopes_requested_but_not_granted = requested_scopes - granted_scopes + if scopes_requested_but_not_granted: + raise exceptions.RefreshError( + "Not all requested scopes were granted by the " + "authorization server, missing scopes {}.".format( + ", ".join(scopes_requested_but_not_granted) + ) + ) + + +class UserAccessTokenCredentials(oauth2_credentials.UserAccessTokenCredentials): + """Access token credentials for user account. + + Obtain the access token for a given user account or the current active + user account with the ``gcloud auth print-access-token`` command. + + Args: + account (Optional[str]): Account to get the access token for. If not + specified, the current active account will be used. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + + """ diff --git a/google/oauth2/service_account_async.py b/google/oauth2/service_account_async.py new file mode 100644 index 000000000..333e06de2 --- /dev/null +++ b/google/oauth2/service_account_async.py @@ -0,0 +1,134 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Service Accounts: JSON Web Token (JWT) Profile for OAuth 2.0 + +NOTE: This file adds asynchronous refresh methods to both credentials +classes, and therefore async/await syntax is required when calling this +method when using service account credentials with asynchronous functionality. +Otherwise, all other methods are inherited from the regular service account +credentials file google.oauth2.service_account + +""" + +from google.auth import _helpers +from google.auth import credentials_async +from google.oauth2 import _client_async +from google.oauth2 import service_account + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds + + +class Credentials( + service_account.Credentials, credentials_async.Scoped, credentials_async.Credentials +): + """Service account credentials + + Usually, you'll create these credentials with one of the helper + constructors. To create credentials using a Google service account + private key JSON file:: + + credentials = service_account_async.Credentials.from_service_account_file( + 'service-account.json') + + Or if you already have the service account file loaded:: + + service_account_info = json.load(open('service_account.json')) + credentials = service_account_async.Credentials.from_service_account_info( + service_account_info) + + Both helper methods pass on arguments to the constructor, so you can + specify additional scopes and a subject if necessary:: + + credentials = service_account_async.Credentials.from_service_account_file( + 'service-account.json', + scopes=['email'], + subject='user@example.com') + + The credentials are considered immutable. If you want to modify the scopes + or the subject used for delegation, use :meth:`with_scopes` or + :meth:`with_subject`:: + + scoped_credentials = credentials.with_scopes(['email']) + delegated_credentials = credentials.with_subject(subject) + + To add a quota project, use :meth:`with_quota_project`:: + + credentials = credentials.with_quota_project('myproject-123') + """ + + @_helpers.copy_docstring(credentials_async.Credentials) + async def refresh(self, request): + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = await _client_async.jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry + + +class IDTokenCredentials( + service_account.IDTokenCredentials, + credentials_async.Signing, + credentials_async.Credentials, +): + """Open ID Connect ID Token-based service account credentials. + + These credentials are largely similar to :class:`.Credentials`, but instead + of using an OAuth 2.0 Access Token as the bearer token, they use an Open + ID Connect ID Token as the bearer token. These credentials are useful when + communicating to services that require ID Tokens and can not accept access + tokens. + + Usually, you'll create these credentials with one of the helper + constructors. To create credentials using a Google service account + private key JSON file:: + + credentials = ( + service_account_async.IDTokenCredentials.from_service_account_file( + 'service-account.json')) + + Or if you already have the service account file loaded:: + + service_account_info = json.load(open('service_account.json')) + credentials = ( + service_account_async.IDTokenCredentials.from_service_account_info( + service_account_info)) + + Both helper methods pass on arguments to the constructor, so you can + specify additional scopes and a subject if necessary:: + + credentials = ( + service_account_async.IDTokenCredentials.from_service_account_file( + 'service-account.json', + scopes=['email'], + subject='user@example.com')) +` + The credentials are considered immutable. If you want to modify the scopes + or the subject used for delegation, use :meth:`with_scopes` or + :meth:`with_subject`:: + + scoped_credentials = credentials.with_scopes(['email']) + delegated_credentials = credentials.with_subject(subject) + + """ + + @_helpers.copy_docstring(credentials_async.Credentials) + async def refresh(self, request): + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = await _client_async.id_token_jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry diff --git a/tests_async/__init__.py b/tests_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_async/conftest.py b/tests_async/conftest.py new file mode 100644 index 000000000..6989a6375 --- /dev/null +++ b/tests_async/conftest.py @@ -0,0 +1,51 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +import mock +import pytest + + +def pytest_configure(): + """Load public certificate and private key.""" + pytest.data_dir = os.path.join( + os.path.abspath(os.path.join(__file__, "../..")), "tests/data" + ) + + with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh: + pytest.private_key_bytes = fh.read() + + with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh: + pytest.public_cert_bytes = fh.read() + + +@pytest.fixture +def mock_non_existent_module(monkeypatch): + """Mocks a non-existing module in sys.modules. + + Additionally mocks any non-existing modules specified in the dotted path. + """ + + def _mock_non_existent_module(path): + parts = path.split(".") + partial = [] + for part in parts: + partial.append(part) + current_module = ".".join(partial) + if current_module not in sys.modules: + monkeypatch.setitem(sys.modules, current_module, mock.MagicMock()) + + return _mock_non_existent_module diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py new file mode 100644 index 000000000..fd5c17dd5 --- /dev/null +++ b/tests_async/oauth2/test__client_async.py @@ -0,0 +1,313 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os + +import mock +import pytest +import six +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import crypt +from google.auth import exceptions +from google.auth import jwt_async as jwt +from google.oauth2 import _client_async as _client + + +DATA_DIR = os.path.join( + os.path.abspath(os.path.join(__file__, "../../..")), "tests/data" +) + +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + +SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") + +SCOPES_AS_LIST = [ + "https://www.googleapis.com/auth/pubsub", + "https://www.googleapis.com/auth/logging.write", +] +SCOPES_AS_STRING = ( + "https://www.googleapis.com/auth/pubsub" + " https://www.googleapis.com/auth/logging.write" +) + + +def test__handle_error_response(): + response_data = json.dumps({"error": "help", "error_description": "I'm alive"}) + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(r"help: I\'m alive") + + +def test__handle_error_response_non_json(): + response_data = "Help, I'm alive" + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(r"Help, I\'m alive") + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +def test__parse_expiry(unused_utcnow): + result = _client._parse_expiry({"expires_in": 500}) + assert result == datetime.datetime.min + datetime.timedelta(seconds=500) + + +def test__parse_expiry_none(): + assert _client._parse_expiry({}) is None + + +def make_request(response_data, status=http_client.OK): + response = mock.AsyncMock(spec=["transport.Response"]) + response.status = status + data = json.dumps(response_data).encode("utf-8") + response.data = mock.AsyncMock(spec=["__call__", "read"]) + response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + request = mock.AsyncMock(spec=["transport.Request"]) + request.return_value = response + return request + + +@pytest.mark.asyncio +async def test__token_endpoint_request(): + + request = make_request({"test": "response"}) + + result = await _client._token_endpoint_request( + request, "http://example.com", {"test": "params"} + ) + + # Check request call + request.assert_called_with( + method="POST", + url="http://example.com", + headers={"content-type": "application/x-www-form-urlencoded"}, + body="test=params".encode("utf-8"), + ) + + # Check result + assert result == {"test": "response"} + + +@pytest.mark.asyncio +async def test__token_endpoint_request_error(): + request = make_request({}, status=http_client.BAD_REQUEST) + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request(request, "http://example.com", {}) + + +@pytest.mark.asyncio +async def test__token_endpoint_request_internal_failure_error(): + request = make_request( + {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST + ) + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request( + request, "http://example.com", {"error_description": "internal_failure"} + ) + + request = make_request( + {"error": "internal_failure"}, status=http_client.BAD_REQUEST + ) + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request( + request, "http://example.com", {"error": "internal_failure"} + ) + + +def verify_request_params(request, params): + request_body = request.call_args[1]["body"].decode("utf-8") + request_params = urllib.parse.parse_qs(request_body) + + for key, value in six.iteritems(params): + assert request_params[key][0] == value + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +@pytest.mark.asyncio +async def test_jwt_grant(utcnow): + request = make_request( + {"access_token": "token", "expires_in": 500, "extra": "data"} + ) + + token, expiry, extra_data = await _client.jwt_grant( + request, "http://example.com", "assertion_value" + ) + + # Check request call + verify_request_params( + request, {"grant_type": _client._JWT_GRANT_TYPE, "assertion": "assertion_value"} + ) + + # Check result + assert token == "token" + assert expiry == utcnow() + datetime.timedelta(seconds=500) + assert extra_data["extra"] == "data" + + +@pytest.mark.asyncio +async def test_jwt_grant_no_access_token(): + request = make_request( + { + # No access token. + "expires_in": 500, + "extra": "data", + } + ) + + with pytest.raises(exceptions.RefreshError): + await _client.jwt_grant(request, "http://example.com", "assertion_value") + + +@pytest.mark.asyncio +async def test_id_token_jwt_grant(): + now = _helpers.utcnow() + id_token_expiry = _helpers.datetime_to_secs(now) + id_token = jwt.encode(SIGNER, {"exp": id_token_expiry}).decode("utf-8") + request = make_request({"id_token": id_token, "extra": "data"}) + + token, expiry, extra_data = await _client.id_token_jwt_grant( + request, "http://example.com", "assertion_value" + ) + + # Check request call + verify_request_params( + request, {"grant_type": _client._JWT_GRANT_TYPE, "assertion": "assertion_value"} + ) + + # Check result + assert token == id_token + # JWT does not store microseconds + now = now.replace(microsecond=0) + assert expiry == now + assert extra_data["extra"] == "data" + + +@pytest.mark.asyncio +async def test_id_token_jwt_grant_no_access_token(): + request = make_request( + { + # No access token. + "expires_in": 500, + "extra": "data", + } + ) + + with pytest.raises(exceptions.RefreshError): + await _client.id_token_jwt_grant( + request, "http://example.com", "assertion_value" + ) + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +@pytest.mark.asyncio +async def test_refresh_grant(unused_utcnow): + request = make_request( + { + "access_token": "token", + "refresh_token": "new_refresh_token", + "expires_in": 500, + "extra": "data", + } + ) + + token, refresh_token, expiry, extra_data = await _client.refresh_grant( + request, "http://example.com", "refresh_token", "client_id", "client_secret" + ) + + # Check request call + verify_request_params( + request, + { + "grant_type": _client._REFRESH_GRANT_TYPE, + "refresh_token": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + }, + ) + + # Check result + assert token == "token" + assert refresh_token == "new_refresh_token" + assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500) + assert extra_data["extra"] == "data" + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +@pytest.mark.asyncio +async def test_refresh_grant_with_scopes(unused_utcnow): + request = make_request( + { + "access_token": "token", + "refresh_token": "new_refresh_token", + "expires_in": 500, + "extra": "data", + "scope": SCOPES_AS_STRING, + } + ) + + token, refresh_token, expiry, extra_data = await _client.refresh_grant( + request, + "http://example.com", + "refresh_token", + "client_id", + "client_secret", + SCOPES_AS_LIST, + ) + + # Check request call. + verify_request_params( + request, + { + "grant_type": _client._REFRESH_GRANT_TYPE, + "refresh_token": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + "scope": SCOPES_AS_STRING, + }, + ) + + # Check result. + assert token == "token" + assert refresh_token == "new_refresh_token" + assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500) + assert extra_data["extra"] == "data" + + +@pytest.mark.asyncio +async def test_refresh_grant_no_access_token(): + request = make_request( + { + # No access token. + "refresh_token": "new_refresh_token", + "expires_in": 500, + "extra": "data", + } + ) + + with pytest.raises(exceptions.RefreshError): + await _client.refresh_grant( + request, "http://example.com", "refresh_token", "client_id", "client_secret" + ) diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py new file mode 100644 index 000000000..a3614b3bd --- /dev/null +++ b/tests_async/oauth2/test_credentials_async.py @@ -0,0 +1,482 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +import pickle +import sys + +import mock +import pytest + +from google.auth import _helpers +from google.auth import exceptions +from google.oauth2 import credentials_async as credentials + +DATA_DIR = os.path.join( + os.path.abspath(os.path.join(__file__, "../../..")), "tests/data" +) + +AUTH_USER_JSON_FILE = os.path.join(DATA_DIR, "authorized_user.json") + +with open(AUTH_USER_JSON_FILE, "r") as fh: + AUTH_USER_INFO = json.load(fh) + + +class TestCredentials: + + TOKEN_URI = "https://example.com/oauth2/token" + REFRESH_TOKEN = "refresh_token" + CLIENT_ID = "client_id" + CLIENT_SECRET = "client_secret" + + @classmethod + def make_credentials(cls): + return credentials.Credentials( + token=None, + refresh_token=cls.REFRESH_TOKEN, + token_uri=cls.TOKEN_URI, + client_id=cls.CLIENT_ID, + client_secret=cls.CLIENT_SECRET, + ) + + def test_default_state(self): + credentials = self.make_credentials() + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expired + # Scopes aren't required for these credentials + assert not credentials.requires_scopes + # Test properties + assert credentials.refresh_token == self.REFRESH_TOKEN + assert credentials.token_uri == self.TOKEN_URI + assert credentials.client_id == self.CLIENT_ID + assert credentials.client_secret == self.CLIENT_SECRET + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_refresh_success(self, unused_utcnow, refresh_grant): + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = {"id_token": mock.sentinel.id_token} + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + credentials = self.make_credentials() + + # Refresh credentials + await credentials.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + None, + ) + + # Check that the credentials have the token and expiry + assert credentials.token == token + assert credentials.expiry == expiry + assert credentials.id_token == mock.sentinel.id_token + + # Check that the credentials are valid (have a token and are not + # expired) + assert credentials.valid + + @pytest.mark.asyncio + async def test_refresh_no_refresh_token(self): + request = mock.AsyncMock(spec=["transport.Request"]) + credentials_ = credentials.Credentials(token=None, refresh_token=None) + + with pytest.raises(exceptions.RefreshError, match="necessary fields"): + await credentials_.refresh(request) + + request.assert_not_called() + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_credentials_with_scopes_requested_refresh_success( + self, unused_utcnow, refresh_grant + ): + scopes = ["email", "profile"] + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = {"id_token": mock.sentinel.id_token} + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = credentials.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + scopes=scopes, + ) + + # Refresh credentials + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + scopes, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.has_scopes(scopes) + + # Check that the credentials are valid (have a token and are not + # expired.) + assert creds.valid + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_credentials_with_scopes_returned_refresh_success( + self, unused_utcnow, refresh_grant + ): + scopes = ["email", "profile"] + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = { + "id_token": mock.sentinel.id_token, + "scopes": " ".join(scopes), + } + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = credentials.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + scopes=scopes, + ) + + # Refresh credentials + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + scopes, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.has_scopes(scopes) + + # Check that the credentials are valid (have a token and are not + # expired.) + assert creds.valid + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( + self, unused_utcnow, refresh_grant + ): + scopes = ["email", "profile"] + scopes_returned = ["email"] + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = { + "id_token": mock.sentinel.id_token, + "scopes": " ".join(scopes_returned), + } + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = credentials.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + scopes=scopes, + ) + + # Refresh credentials + with pytest.raises( + exceptions.RefreshError, match="Not all requested scopes were granted" + ): + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + scopes, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.has_scopes(scopes) + + # Check that the credentials are valid (have a token and are not + # expired.) + assert creds.valid + + def test_apply_with_quota_project_id(self): + creds = credentials.Credentials( + token="token", + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + quota_project_id="quota-project-123", + ) + + headers = {} + creds.apply(headers) + assert headers["x-goog-user-project"] == "quota-project-123" + + def test_apply_with_no_quota_project_id(self): + creds = credentials.Credentials( + token="token", + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + ) + + headers = {} + creds.apply(headers) + assert "x-goog-user-project" not in headers + + def test_with_quota_project(self): + creds = credentials.Credentials( + token="token", + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + quota_project_id="quota-project-123", + ) + + new_creds = creds.with_quota_project("new-project-456") + assert new_creds.quota_project_id == "new-project-456" + headers = {} + creds.apply(headers) + assert "x-goog-user-project" in headers + + def test_from_authorized_user_info(self): + info = AUTH_USER_INFO.copy() + + creds = credentials.Credentials.from_authorized_user_info(info) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes is None + + scopes = ["email", "profile"] + creds = credentials.Credentials.from_authorized_user_info(info, scopes) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes == scopes + + def test_from_authorized_user_file(self): + info = AUTH_USER_INFO.copy() + + creds = credentials.Credentials.from_authorized_user_file(AUTH_USER_JSON_FILE) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes is None + + scopes = ["email", "profile"] + creds = credentials.Credentials.from_authorized_user_file( + AUTH_USER_JSON_FILE, scopes + ) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes == scopes + + def test_to_json(self): + info = AUTH_USER_INFO.copy() + creds = credentials.Credentials.from_authorized_user_info(info) + + # Test with no `strip` arg + json_output = creds.to_json() + json_asdict = json.loads(json_output) + assert json_asdict.get("token") == creds.token + assert json_asdict.get("refresh_token") == creds.refresh_token + assert json_asdict.get("token_uri") == creds.token_uri + assert json_asdict.get("client_id") == creds.client_id + assert json_asdict.get("scopes") == creds.scopes + assert json_asdict.get("client_secret") == creds.client_secret + + # Test with a `strip` arg + json_output = creds.to_json(strip=["client_secret"]) + json_asdict = json.loads(json_output) + assert json_asdict.get("token") == creds.token + assert json_asdict.get("refresh_token") == creds.refresh_token + assert json_asdict.get("token_uri") == creds.token_uri + assert json_asdict.get("client_id") == creds.client_id + assert json_asdict.get("scopes") == creds.scopes + assert json_asdict.get("client_secret") is None + + def test_pickle_and_unpickle(self): + creds = self.make_credentials() + unpickled = pickle.loads(pickle.dumps(creds)) + + # make sure attributes aren't lost during pickling + assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort() + + for attr in list(creds.__dict__): + assert getattr(creds, attr) == getattr(unpickled, attr) + + def test_pickle_with_missing_attribute(self): + creds = self.make_credentials() + + # remove an optional attribute before pickling + # this mimics a pickle created with a previous class definition with + # fewer attributes + del creds.__dict__["_quota_project_id"] + + unpickled = pickle.loads(pickle.dumps(creds)) + + # Attribute should be initialized by `__setstate__` + assert unpickled.quota_project_id is None + + # pickles are not compatible across versions + @pytest.mark.skipif( + sys.version_info < (3, 5), + reason="pickle file can only be loaded with Python >= 3.5", + ) + def test_unpickle_old_credentials_pickle(self): + # make sure a credentials file pickled with an older + # library version (google-auth==1.5.1) can be unpickled + with open( + os.path.join(DATA_DIR, "old_oauth_credentials_py3.pickle"), "rb" + ) as f: + credentials = pickle.load(f) + assert credentials.quota_project_id is None + + +class TestUserAccessTokenCredentials(object): + def test_instance(self): + cred = credentials.UserAccessTokenCredentials() + assert cred._account is None + + cred = cred.with_account("account") + assert cred._account == "account" + + @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True) + def test_refresh(self, get_auth_access_token): + get_auth_access_token.return_value = "access_token" + cred = credentials.UserAccessTokenCredentials() + cred.refresh(None) + assert cred.token == "access_token" + + def test_with_quota_project(self): + cred = credentials.UserAccessTokenCredentials() + quota_project_cred = cred.with_quota_project("project-foo") + + assert quota_project_cred._quota_project_id == "project-foo" + assert quota_project_cred._account == cred._account + + @mock.patch( + "google.oauth2.credentials_async.UserAccessTokenCredentials.apply", + autospec=True, + ) + @mock.patch( + "google.oauth2.credentials_async.UserAccessTokenCredentials.refresh", + autospec=True, + ) + def test_before_request(self, refresh, apply): + cred = credentials.UserAccessTokenCredentials() + cred.before_request(mock.Mock(), "GET", "https://example.com", {}) + refresh.assert_called() + apply.assert_called() diff --git a/tests_async/oauth2/test_service_account_async.py b/tests_async/oauth2/test_service_account_async.py new file mode 100644 index 000000000..22a1876da --- /dev/null +++ b/tests_async/oauth2/test_service_account_async.py @@ -0,0 +1,367 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os + +import mock +import pytest + +from google.auth import _helpers +from google.auth import crypt +from google.auth import jwt +from google.auth import transport +from google.oauth2 import service_account_async as service_account + + +DATA_DIR = os.path.join( + os.path.abspath(os.path.join(__file__, "../../..")), "tests/data" +) + +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh: + OTHER_CERT_BYTES = fh.read() + +SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") + +with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + SERVICE_ACCOUNT_INFO = json.load(fh) + +SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") + + +class TestCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + TOKEN_URI = "https://example.com/oauth2/token" + + @classmethod + def make_credentials(cls): + return service_account.Credentials( + SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI + ) + + def test_from_service_account_info(self): + credentials = service_account.Credentials.from_service_account_info( + SERVICE_ACCOUNT_INFO + ) + + assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"] + assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"] + assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"] + + def test_from_service_account_info_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + scopes = ["email", "profile"] + subject = "subject" + additional_claims = {"meta": "data"} + + credentials = service_account.Credentials.from_service_account_info( + info, scopes=scopes, subject=subject, additional_claims=additional_claims + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials.project_id == info["project_id"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + assert credentials._scopes == scopes + assert credentials._subject == subject + assert credentials._additional_claims == additional_claims + + def test_from_service_account_file(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = service_account.Credentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials.project_id == info["project_id"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + + def test_from_service_account_file_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + scopes = ["email", "profile"] + subject = "subject" + additional_claims = {"meta": "data"} + + credentials = service_account.Credentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE, + subject=subject, + scopes=scopes, + additional_claims=additional_claims, + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials.project_id == info["project_id"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + assert credentials._scopes == scopes + assert credentials._subject == subject + assert credentials._additional_claims == additional_claims + + def test_default_state(self): + credentials = self.make_credentials() + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expired + # Scopes haven't been specified yet + assert credentials.requires_scopes + + def test_sign_bytes(self): + credentials = self.make_credentials() + to_sign = b"123" + signature = credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + + def test_signer(self): + credentials = self.make_credentials() + assert isinstance(credentials.signer, crypt.Signer) + + def test_signer_email(self): + credentials = self.make_credentials() + assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL + + def test_create_scoped(self): + credentials = self.make_credentials() + scopes = ["email", "profile"] + credentials = credentials.with_scopes(scopes) + assert credentials._scopes == scopes + + def test_with_claims(self): + credentials = self.make_credentials() + new_credentials = credentials.with_claims({"meep": "moop"}) + assert new_credentials._additional_claims == {"meep": "moop"} + + def test_with_quota_project(self): + credentials = self.make_credentials() + new_credentials = credentials.with_quota_project("new-project-456") + assert new_credentials.quota_project_id == "new-project-456" + hdrs = {} + new_credentials.apply(hdrs, token="tok") + assert "x-goog-user-project" in hdrs + + def test__make_authorization_grant_assertion(self): + credentials = self.make_credentials() + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + assert payload["aud"] == self.TOKEN_URI + + def test__make_authorization_grant_assertion_scoped(self): + credentials = self.make_credentials() + scopes = ["email", "profile"] + credentials = credentials.with_scopes(scopes) + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload["scope"] == "email profile" + + def test__make_authorization_grant_assertion_subject(self): + credentials = self.make_credentials() + subject = "user@example.com" + credentials = credentials.with_subject(subject) + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload["sub"] == subject + + @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_refresh_success(self, jwt_grant): + credentials = self.make_credentials() + token = "token" + jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + {}, + ) + request = mock.create_autospec(transport.Request, instance=True) + + # Refresh credentials + await credentials.refresh(request) + + # Check jwt grant call. + assert jwt_grant.called + + called_request, token_uri, assertion = jwt_grant.call_args[0] + assert called_request == request + assert token_uri == credentials._token_uri + assert jwt.decode(assertion, PUBLIC_CERT_BYTES) + # No further assertion done on the token, as there are separate tests + # for checking the authorization grant assertion. + + # Check that the credentials have the token. + assert credentials.token == token + + # Check that the credentials are valid (have a token and are not + # expired) + assert credentials.valid + + @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_before_request_refreshes(self, jwt_grant): + credentials = self.make_credentials() + token = "token" + jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + None, + ) + request = mock.create_autospec(transport.Request, instance=True) + + # Credentials should start as invalid + assert not credentials.valid + + # before_request should cause a refresh + await credentials.before_request(request, "GET", "http://example.com?a=1#3", {}) + + # The refresh endpoint should've been called. + assert jwt_grant.called + + # Credentials should now be valid. + assert credentials.valid + + +class TestIDTokenCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + TOKEN_URI = "https://example.com/oauth2/token" + TARGET_AUDIENCE = "https://example.com" + + @classmethod + def make_credentials(cls): + return service_account.IDTokenCredentials( + SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI, cls.TARGET_AUDIENCE + ) + + def test_from_service_account_info(self): + credentials = service_account.IDTokenCredentials.from_service_account_info( + SERVICE_ACCOUNT_INFO, target_audience=self.TARGET_AUDIENCE + ) + + assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"] + assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"] + assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"] + assert credentials._target_audience == self.TARGET_AUDIENCE + + def test_from_service_account_file(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = service_account.IDTokenCredentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE, target_audience=self.TARGET_AUDIENCE + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + assert credentials._target_audience == self.TARGET_AUDIENCE + + def test_default_state(self): + credentials = self.make_credentials() + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expired + + def test_sign_bytes(self): + credentials = self.make_credentials() + to_sign = b"123" + signature = credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + + def test_signer(self): + credentials = self.make_credentials() + assert isinstance(credentials.signer, crypt.Signer) + + def test_signer_email(self): + credentials = self.make_credentials() + assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL + + def test_with_target_audience(self): + credentials = self.make_credentials() + new_credentials = credentials.with_target_audience("https://new.example.com") + assert new_credentials._target_audience == "https://new.example.com" + + def test_with_quota_project(self): + credentials = self.make_credentials() + new_credentials = credentials.with_quota_project("project-foo") + assert new_credentials._quota_project_id == "project-foo" + + def test__make_authorization_grant_assertion(self): + credentials = self.make_credentials() + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + assert payload["aud"] == self.TOKEN_URI + assert payload["target_audience"] == self.TARGET_AUDIENCE + + @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_refresh_success(self, id_token_jwt_grant): + credentials = self.make_credentials() + token = "token" + id_token_jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + {}, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + + # Refresh credentials + await credentials.refresh(request) + + # Check jwt grant call. + assert id_token_jwt_grant.called + + called_request, token_uri, assertion = id_token_jwt_grant.call_args[0] + assert called_request == request + assert token_uri == credentials._token_uri + assert jwt.decode(assertion, PUBLIC_CERT_BYTES) + # No further assertion done on the token, as there are separate tests + # for checking the authorization grant assertion. + + # Check that the credentials have the token. + assert credentials.token == token + + # Check that the credentials are valid (have a token and are not + # expired) + assert credentials.valid + + @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_before_request_refreshes(self, id_token_jwt_grant): + credentials = self.make_credentials() + token = "token" + id_token_jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + None, + ) + request = mock.AsyncMock(spec=["transport.Request"]) + + # Credentials should start as invalid + assert not credentials.valid + + # before_request should cause a refresh + await credentials.before_request(request, "GET", "http://example.com?a=1#3", {}) + + # The refresh endpoint should've been called. + assert id_token_jwt_grant.called + + # Credentials should now be valid. + assert credentials.valid diff --git a/tests_async/test__default.py b/tests_async/test__default.py new file mode 100644 index 000000000..d41edf76c --- /dev/null +++ b/tests_async/test__default.py @@ -0,0 +1,484 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import mock +import pytest + +from google.auth import _default_async as _default +from google.auth import app_engine +from google.auth import compute_engine +from google.auth import credentials_async as credentials +from google.auth import environment_vars +from google.auth import exceptions +from google.oauth2 import service_account_async as service_account +import google.oauth2.credentials + + +DATA_DIR = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "tests/data") +AUTHORIZED_USER_FILE = os.path.join(DATA_DIR, "authorized_user.json") + +with open(AUTHORIZED_USER_FILE) as fh: + AUTHORIZED_USER_FILE_DATA = json.load(fh) + +AUTHORIZED_USER_CLOUD_SDK_FILE = os.path.join( + DATA_DIR, "authorized_user_cloud_sdk.json" +) + +AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE = os.path.join( + DATA_DIR, "authorized_user_cloud_sdk_with_quota_project_id.json" +) + +SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "service_account.json") + +CLIENT_SECRETS_FILE = os.path.join(DATA_DIR, "client_secrets.json") + +with open(SERVICE_ACCOUNT_FILE) as fh: + SERVICE_ACCOUNT_FILE_DATA = json.load(fh) + +MOCK_CREDENTIALS = mock.Mock(spec=credentials.Credentials) +MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS + +LOAD_FILE_PATCH = mock.patch( + "google.auth._default_async.load_credentials_from_file", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) + + +def test_load_credentials_from_missing_file(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file("") + + assert excinfo.match(r"not found") + + +def test_load_credentials_from_file_invalid_json(tmpdir): + jsonfile = tmpdir.join("invalid.json") + jsonfile.write("{") + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(jsonfile)) + + assert excinfo.match(r"not a valid json file") + + +def test_load_credentials_from_file_invalid_type(tmpdir): + jsonfile = tmpdir.join("invalid.json") + jsonfile.write(json.dumps({"type": "not-a-real-type"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(jsonfile)) + + assert excinfo.match(r"does not have a valid type") + + +def test_load_credentials_from_file_authorized_user(): + credentials, project_id = _default.load_credentials_from_file(AUTHORIZED_USER_FILE) + assert isinstance(credentials, google.oauth2.credentials_async.Credentials) + assert project_id is None + + +def test_load_credentials_from_file_no_type(tmpdir): + # use the client_secrets.json, which is valid json but not a + # loadable credentials type + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(CLIENT_SECRETS_FILE) + + assert excinfo.match(r"does not have a valid type") + assert excinfo.match(r"Type is None") + + +def test_load_credentials_from_file_authorized_user_bad_format(tmpdir): + filename = tmpdir.join("authorized_user_bad.json") + filename.write(json.dumps({"type": "authorized_user"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(filename)) + + assert excinfo.match(r"Failed to load authorized user") + assert excinfo.match(r"missing fields") + + +def test_load_credentials_from_file_authorized_user_cloud_sdk(): + with pytest.warns(UserWarning, match="Cloud SDK"): + credentials, project_id = _default.load_credentials_from_file( + AUTHORIZED_USER_CLOUD_SDK_FILE + ) + assert isinstance(credentials, google.oauth2.credentials_async.Credentials) + assert project_id is None + + # No warning if the json file has quota project id. + credentials, project_id = _default.load_credentials_from_file( + AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE + ) + assert isinstance(credentials, google.oauth2.credentials_async.Credentials) + assert project_id is None + + +def test_load_credentials_from_file_authorized_user_cloud_sdk_with_scopes(): + with pytest.warns(UserWarning, match="Cloud SDK"): + credentials, project_id = _default.load_credentials_from_file( + AUTHORIZED_USER_CLOUD_SDK_FILE, + scopes=["https://www.google.com/calendar/feeds"], + ) + assert isinstance(credentials, google.oauth2.credentials_async.Credentials) + assert project_id is None + assert credentials.scopes == ["https://www.google.com/calendar/feeds"] + + +def test_load_credentials_from_file_authorized_user_cloud_sdk_with_quota_project(): + credentials, project_id = _default.load_credentials_from_file( + AUTHORIZED_USER_CLOUD_SDK_FILE, quota_project_id="project-foo" + ) + + assert isinstance(credentials, google.oauth2.credentials_async.Credentials) + assert project_id is None + assert credentials.quota_project_id == "project-foo" + + +def test_load_credentials_from_file_service_account(): + credentials, project_id = _default.load_credentials_from_file(SERVICE_ACCOUNT_FILE) + assert isinstance(credentials, service_account.Credentials) + assert project_id == SERVICE_ACCOUNT_FILE_DATA["project_id"] + + +def test_load_credentials_from_file_service_account_with_scopes(): + credentials, project_id = _default.load_credentials_from_file( + SERVICE_ACCOUNT_FILE, scopes=["https://www.google.com/calendar/feeds"] + ) + assert isinstance(credentials, service_account.Credentials) + assert project_id == SERVICE_ACCOUNT_FILE_DATA["project_id"] + assert credentials.scopes == ["https://www.google.com/calendar/feeds"] + + +def test_load_credentials_from_file_service_account_bad_format(tmpdir): + filename = tmpdir.join("serivce_account_bad.json") + filename.write(json.dumps({"type": "service_account"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(filename)) + + assert excinfo.match(r"Failed to load service account") + assert excinfo.match(r"missing fields") + + +@mock.patch.dict(os.environ, {}, clear=True) +def test__get_explicit_environ_credentials_no_env(): + assert _default._get_explicit_environ_credentials() == (None, None) + + +@LOAD_FILE_PATCH +def test__get_explicit_environ_credentials(load, monkeypatch): + monkeypatch.setenv(environment_vars.CREDENTIALS, "filename") + + credentials, project_id = _default._get_explicit_environ_credentials() + + assert credentials is MOCK_CREDENTIALS + assert project_id is mock.sentinel.project_id + load.assert_called_with("filename") + + +@LOAD_FILE_PATCH +def test__get_explicit_environ_credentials_no_project_id(load, monkeypatch): + load.return_value = MOCK_CREDENTIALS, None + monkeypatch.setenv(environment_vars.CREDENTIALS, "filename") + + credentials, project_id = _default._get_explicit_environ_credentials() + + assert credentials is MOCK_CREDENTIALS + assert project_id is None + + +@LOAD_FILE_PATCH +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +def test__get_gcloud_sdk_credentials(get_adc_path, load): + get_adc_path.return_value = SERVICE_ACCOUNT_FILE + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials is MOCK_CREDENTIALS + assert project_id is mock.sentinel.project_id + load.assert_called_with(SERVICE_ACCOUNT_FILE) + + +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +def test__get_gcloud_sdk_credentials_non_existent(get_adc_path, tmpdir): + non_existent = tmpdir.join("non-existent") + get_adc_path.return_value = str(non_existent) + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials is None + assert project_id is None + + +@mock.patch( + "google.auth._cloud_sdk.get_project_id", + return_value=mock.sentinel.project_id, + autospec=True, +) +@mock.patch("os.path.isfile", return_value=True, autospec=True) +@LOAD_FILE_PATCH +def test__get_gcloud_sdk_credentials_project_id(load, unused_isfile, get_project_id): + # Don't return a project ID from load file, make the function check + # the Cloud SDK project. + load.return_value = MOCK_CREDENTIALS, None + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials == MOCK_CREDENTIALS + assert project_id == mock.sentinel.project_id + assert get_project_id.called + + +@mock.patch("google.auth._cloud_sdk.get_project_id", return_value=None, autospec=True) +@mock.patch("os.path.isfile", return_value=True) +@LOAD_FILE_PATCH +def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_project_id): + # Don't return a project ID from load file, make the function check + # the Cloud SDK project. + load.return_value = MOCK_CREDENTIALS, None + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials == MOCK_CREDENTIALS + assert project_id is None + assert get_project_id.called + + +class _AppIdentityModule(object): + """The interface of the App Idenity app engine module. + See https://cloud.google.com/appengine/docs/standard/python/refdocs\ + /google.appengine.api.app_identity.app_identity + """ + + def get_application_id(self): + raise NotImplementedError() + + +@pytest.fixture +def app_identity(monkeypatch): + """Mocks the app_identity module for google.auth.app_engine.""" + app_identity_module = mock.create_autospec(_AppIdentityModule, instance=True) + monkeypatch.setattr(app_engine, "app_identity", app_identity_module) + yield app_identity_module + + +def test__get_gae_credentials(app_identity): + app_identity.get_application_id.return_value = mock.sentinel.project + + credentials, project_id = _default._get_gae_credentials() + + assert isinstance(credentials, app_engine.Credentials) + assert project_id == mock.sentinel.project + + +def test__get_gae_credentials_no_app_engine(): + import sys + + with mock.patch.dict("sys.modules"): + sys.modules["google.auth.app_engine"] = None + credentials, project_id = _default._get_gae_credentials() + assert credentials is None + assert project_id is None + + +def test__get_gae_credentials_no_apis(): + assert _default._get_gae_credentials() == (None, None) + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=True, autospec=True +) +@mock.patch( + "google.auth.compute_engine._metadata.get_project_id", + return_value="example-project", + autospec=True, +) +def test__get_gce_credentials(unused_get, unused_ping): + credentials, project_id = _default._get_gce_credentials() + + assert isinstance(credentials, compute_engine.Credentials) + assert project_id == "example-project" + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=False, autospec=True +) +def test__get_gce_credentials_no_ping(unused_ping): + credentials, project_id = _default._get_gce_credentials() + + assert credentials is None + assert project_id is None + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=True, autospec=True +) +@mock.patch( + "google.auth.compute_engine._metadata.get_project_id", + side_effect=exceptions.TransportError(), + autospec=True, +) +def test__get_gce_credentials_no_project_id(unused_get, unused_ping): + credentials, project_id = _default._get_gce_credentials() + + assert isinstance(credentials, compute_engine.Credentials) + assert project_id is None + + +def test__get_gce_credentials_no_compute_engine(): + import sys + + with mock.patch.dict("sys.modules"): + sys.modules["google.auth.compute_engine"] = None + credentials, project_id = _default._get_gce_credentials() + assert credentials is None + assert project_id is None + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=False, autospec=True +) +def test__get_gce_credentials_explicit_request(ping): + _default._get_gce_credentials(mock.sentinel.request) + ping.assert_called_with(request=mock.sentinel.request) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_early_out(unused_get): + assert _default.default_async() == (MOCK_CREDENTIALS, mock.sentinel.project_id) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_explict_project_id(unused_get, monkeypatch): + monkeypatch.setenv(environment_vars.PROJECT, "explicit-env") + assert _default.default_async() == (MOCK_CREDENTIALS, "explicit-env") + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_explict_legacy_project_id(unused_get, monkeypatch): + monkeypatch.setenv(environment_vars.LEGACY_PROJECT, "explicit-env") + assert _default.default_async() == (MOCK_CREDENTIALS, "explicit-env") + + +@mock.patch("logging.Logger.warning", autospec=True) +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gcloud_sdk_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gae_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gce_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +def test_default_without_project_id( + unused_gce, unused_gae, unused_sdk, unused_explicit, logger_warning +): + assert _default.default_async() == (MOCK_CREDENTIALS, None) + logger_warning.assert_called_with(mock.ANY, mock.ANY, mock.ANY) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(None, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gcloud_sdk_credentials", + return_value=(None, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gae_credentials", + return_value=(None, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gce_credentials", + return_value=(None, None), + autospec=True, +) +def test_default_fail(unused_gce, unused_gae, unused_sdk, unused_explicit): + with pytest.raises(exceptions.DefaultCredentialsError): + assert _default.default_async() + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +@mock.patch( + "google.auth.credentials_async.with_scopes_if_required", + return_value=MOCK_CREDENTIALS, + autospec=True, +) +def test_default_scoped(with_scopes, unused_get): + scopes = ["one", "two"] + + credentials, project_id = _default.default_async(scopes=scopes) + + assert credentials == with_scopes.return_value + assert project_id == mock.sentinel.project_id + with_scopes.assert_called_once_with(MOCK_CREDENTIALS, scopes) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_no_app_engine_compute_engine_module(unused_get): + """ + google.auth.compute_engine and google.auth.app_engine are both optional + to allow not including them when using this package. This verifies + that default fails gracefully if these modules are absent + """ + import sys + + with mock.patch.dict("sys.modules"): + sys.modules["google.auth.compute_engine"] = None + sys.modules["google.auth.app_engine"] = None + assert _default.default_async() == (MOCK_CREDENTIALS, mock.sentinel.project_id) diff --git a/tests_async/test__oauth2client.py b/tests_async/test__oauth2client.py new file mode 100644 index 000000000..ce0b7defc --- /dev/null +++ b/tests_async/test__oauth2client.py @@ -0,0 +1,170 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +import sys + +import mock +import oauth2client.client +import oauth2client.contrib.gce +import oauth2client.service_account +import pytest +from six.moves import reload_module + +from google.auth import _oauth2client_async as _oauth2client + + +DATA_DIR = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "tests/data") +SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") + + +def test__convert_oauth2_credentials(): + old_credentials = oauth2client.client.OAuth2Credentials( + "access_token", + "client_id", + "client_secret", + "refresh_token", + datetime.datetime.min, + "token_uri", + "user_agent", + scopes="one two", + ) + + new_credentials = _oauth2client._convert_oauth2_credentials(old_credentials) + + assert new_credentials.token == old_credentials.access_token + assert new_credentials._refresh_token == old_credentials.refresh_token + assert new_credentials._client_id == old_credentials.client_id + assert new_credentials._client_secret == old_credentials.client_secret + assert new_credentials._token_uri == old_credentials.token_uri + assert new_credentials.scopes == old_credentials.scopes + + +def test__convert_service_account_credentials(): + old_class = oauth2client.service_account.ServiceAccountCredentials + old_credentials = old_class.from_json_keyfile_name(SERVICE_ACCOUNT_JSON_FILE) + + new_credentials = _oauth2client._convert_service_account_credentials( + old_credentials + ) + + assert ( + new_credentials.service_account_email == old_credentials.service_account_email + ) + assert new_credentials._signer.key_id == old_credentials._private_key_id + assert new_credentials._token_uri == old_credentials.token_uri + + +def test__convert_service_account_credentials_with_jwt(): + old_class = oauth2client.service_account._JWTAccessCredentials + old_credentials = old_class.from_json_keyfile_name(SERVICE_ACCOUNT_JSON_FILE) + + new_credentials = _oauth2client._convert_service_account_credentials( + old_credentials + ) + + assert ( + new_credentials.service_account_email == old_credentials.service_account_email + ) + assert new_credentials._signer.key_id == old_credentials._private_key_id + assert new_credentials._token_uri == old_credentials.token_uri + + +def test__convert_gce_app_assertion_credentials(): + old_credentials = oauth2client.contrib.gce.AppAssertionCredentials( + email="some_email" + ) + + new_credentials = _oauth2client._convert_gce_app_assertion_credentials( + old_credentials + ) + + assert ( + new_credentials.service_account_email == old_credentials.service_account_email + ) + + +@pytest.fixture +def mock_oauth2client_gae_imports(mock_non_existent_module): + mock_non_existent_module("google.appengine.api.app_identity") + mock_non_existent_module("google.appengine.ext.ndb") + mock_non_existent_module("google.appengine.ext.webapp.util") + mock_non_existent_module("webapp2") + + +@mock.patch("google.auth.app_engine.app_identity") +def test__convert_appengine_app_assertion_credentials( + app_identity, mock_oauth2client_gae_imports +): + + import oauth2client.contrib.appengine + + service_account_id = "service_account_id" + old_credentials = oauth2client.contrib.appengine.AppAssertionCredentials( + scope="one two", service_account_id=service_account_id + ) + + new_credentials = _oauth2client._convert_appengine_app_assertion_credentials( + old_credentials + ) + + assert new_credentials.scopes == ["one", "two"] + assert new_credentials._service_account_id == old_credentials.service_account_id + + +class FakeCredentials(object): + pass + + +def test_convert_success(): + convert_function = mock.Mock(spec=["__call__"]) + conversion_map_patch = mock.patch.object( + _oauth2client, "_CLASS_CONVERSION_MAP", {FakeCredentials: convert_function} + ) + credentials = FakeCredentials() + + with conversion_map_patch: + result = _oauth2client.convert(credentials) + + convert_function.assert_called_once_with(credentials) + assert result == convert_function.return_value + + +def test_convert_not_found(): + with pytest.raises(ValueError) as excinfo: + _oauth2client.convert("a string is not a real credentials class") + + assert excinfo.match("Unable to convert") + + +@pytest.fixture +def reset__oauth2client_module(): + """Reloads the _oauth2client module after a test.""" + reload_module(_oauth2client) + + +def test_import_has_app_engine( + mock_oauth2client_gae_imports, reset__oauth2client_module +): + reload_module(_oauth2client) + assert _oauth2client._HAS_APPENGINE + + +def test_import_without_oauth2client(monkeypatch, reset__oauth2client_module): + monkeypatch.setitem(sys.modules, "oauth2client", None) + with pytest.raises(ImportError) as excinfo: + reload_module(_oauth2client) + + assert excinfo.match("oauth2client") diff --git a/tests_async/test_jwt.py b/tests_async/test_jwt.py new file mode 100644 index 000000000..c3720b075 --- /dev/null +++ b/tests_async/test_jwt.py @@ -0,0 +1,583 @@ +# Copyright 2014 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import datetime +import json +import os + +import mock +import pytest + +from google.auth import _helpers +from google.auth import crypt +from google.auth import exceptions +from google.auth import jwt_async as jwt + + +DATA_DIR = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "tests/data") + +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh: + OTHER_CERT_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "es256_privatekey.pem"), "rb") as fh: + EC_PRIVATE_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "es256_public_cert.pem"), "rb") as fh: + EC_PUBLIC_CERT_BYTES = fh.read() + +SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") + +with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + SERVICE_ACCOUNT_INFO = json.load(fh) + + +@pytest.fixture +def signer(): + return crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") + + +def test_encode_basic(signer): + test_payload = {"test": "value"} + encoded = jwt.encode(signer, test_payload) + header, payload, _, _ = jwt._unverified_decode(encoded) + assert payload == test_payload + assert header == {"typ": "JWT", "alg": "RS256", "kid": signer.key_id} + + +def test_encode_extra_headers(signer): + encoded = jwt.encode(signer, {}, header={"extra": "value"}) + header = jwt.decode_header(encoded) + assert header == { + "typ": "JWT", + "alg": "RS256", + "kid": signer.key_id, + "extra": "value", + } + + +@pytest.fixture +def es256_signer(): + return crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1") + + +def test_encode_basic_es256(es256_signer): + test_payload = {"test": "value"} + encoded = jwt.encode(es256_signer, test_payload) + header, payload, _, _ = jwt._unverified_decode(encoded) + assert payload == test_payload + assert header == {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id} + + +@pytest.fixture +def token_factory(signer, es256_signer): + def factory(claims=None, key_id=None, use_es256_signer=False): + now = _helpers.datetime_to_secs(_helpers.utcnow()) + payload = { + "aud": "audience@example.com", + "iat": now, + "exp": now + 300, + "user": "billy bob", + "metadata": {"meta": "data"}, + } + payload.update(claims or {}) + + # False is specified to remove the signer's key id for testing + # headers without key ids. + if key_id is False: + signer._key_id = None + key_id = None + + if use_es256_signer: + return jwt.encode(es256_signer, payload, key_id=key_id) + else: + return jwt.encode(signer, payload, key_id=key_id) + + return factory + + +def test_decode_valid(token_factory): + payload = jwt.decode(token_factory(), certs=PUBLIC_CERT_BYTES) + assert payload["aud"] == "audience@example.com" + assert payload["user"] == "billy bob" + assert payload["metadata"]["meta"] == "data" + + +def test_decode_valid_es256(token_factory): + payload = jwt.decode( + token_factory(use_es256_signer=True), certs=EC_PUBLIC_CERT_BYTES + ) + assert payload["aud"] == "audience@example.com" + assert payload["user"] == "billy bob" + assert payload["metadata"]["meta"] == "data" + + +def test_decode_valid_with_audience(token_factory): + payload = jwt.decode( + token_factory(), certs=PUBLIC_CERT_BYTES, audience="audience@example.com" + ) + assert payload["aud"] == "audience@example.com" + assert payload["user"] == "billy bob" + assert payload["metadata"]["meta"] == "data" + + +def test_decode_valid_unverified(token_factory): + payload = jwt.decode(token_factory(), certs=OTHER_CERT_BYTES, verify=False) + assert payload["aud"] == "audience@example.com" + assert payload["user"] == "billy bob" + assert payload["metadata"]["meta"] == "data" + + +def test_decode_bad_token_wrong_number_of_segments(): + with pytest.raises(ValueError) as excinfo: + jwt.decode("1.2", PUBLIC_CERT_BYTES) + assert excinfo.match(r"Wrong number of segments") + + +def test_decode_bad_token_not_base64(): + with pytest.raises((ValueError, TypeError)) as excinfo: + jwt.decode("1.2.3", PUBLIC_CERT_BYTES) + assert excinfo.match(r"Incorrect padding|more than a multiple of 4") + + +def test_decode_bad_token_not_json(): + token = b".".join([base64.urlsafe_b64encode(b"123!")] * 3) + with pytest.raises(ValueError) as excinfo: + jwt.decode(token, PUBLIC_CERT_BYTES) + assert excinfo.match(r"Can\'t parse segment") + + +def test_decode_bad_token_no_iat_or_exp(signer): + token = jwt.encode(signer, {"test": "value"}) + with pytest.raises(ValueError) as excinfo: + jwt.decode(token, PUBLIC_CERT_BYTES) + assert excinfo.match(r"Token does not contain required claim") + + +def test_decode_bad_token_too_early(token_factory): + token = token_factory( + claims={ + "iat": _helpers.datetime_to_secs( + _helpers.utcnow() + datetime.timedelta(hours=1) + ) + } + ) + with pytest.raises(ValueError) as excinfo: + jwt.decode(token, PUBLIC_CERT_BYTES) + assert excinfo.match(r"Token used too early") + + +def test_decode_bad_token_expired(token_factory): + token = token_factory( + claims={ + "exp": _helpers.datetime_to_secs( + _helpers.utcnow() - datetime.timedelta(hours=1) + ) + } + ) + with pytest.raises(ValueError) as excinfo: + jwt.decode(token, PUBLIC_CERT_BYTES) + assert excinfo.match(r"Token expired") + + +def test_decode_bad_token_wrong_audience(token_factory): + token = token_factory() + audience = "audience2@example.com" + with pytest.raises(ValueError) as excinfo: + jwt.decode(token, PUBLIC_CERT_BYTES, audience=audience) + assert excinfo.match(r"Token has wrong audience") + + +def test_decode_wrong_cert(token_factory): + with pytest.raises(ValueError) as excinfo: + jwt.decode(token_factory(), OTHER_CERT_BYTES) + assert excinfo.match(r"Could not verify token signature") + + +def test_decode_multicert_bad_cert(token_factory): + certs = {"1": OTHER_CERT_BYTES, "2": PUBLIC_CERT_BYTES} + with pytest.raises(ValueError) as excinfo: + jwt.decode(token_factory(), certs) + assert excinfo.match(r"Could not verify token signature") + + +def test_decode_no_cert(token_factory): + certs = {"2": PUBLIC_CERT_BYTES} + with pytest.raises(ValueError) as excinfo: + jwt.decode(token_factory(), certs) + assert excinfo.match(r"Certificate for key id 1 not found") + + +def test_decode_no_key_id(token_factory): + token = token_factory(key_id=False) + certs = {"2": PUBLIC_CERT_BYTES} + payload = jwt.decode(token, certs) + assert payload["user"] == "billy bob" + + +def test_decode_unknown_alg(): + headers = json.dumps({u"kid": u"1", u"alg": u"fakealg"}) + token = b".".join( + map(lambda seg: base64.b64encode(seg.encode("utf-8")), [headers, u"{}", u"sig"]) + ) + + with pytest.raises(ValueError) as excinfo: + jwt.decode(token) + assert excinfo.match(r"fakealg") + + +def test_decode_missing_crytography_alg(monkeypatch): + monkeypatch.delitem(jwt._ALGORITHM_TO_VERIFIER_CLASS, "ES256") + headers = json.dumps({u"kid": u"1", u"alg": u"ES256"}) + token = b".".join( + map(lambda seg: base64.b64encode(seg.encode("utf-8")), [headers, u"{}", u"sig"]) + ) + + with pytest.raises(ValueError) as excinfo: + jwt.decode(token) + assert excinfo.match(r"cryptography") + + +def test_roundtrip_explicit_key_id(token_factory): + token = token_factory(key_id="3") + certs = {"2": OTHER_CERT_BYTES, "3": PUBLIC_CERT_BYTES} + payload = jwt.decode(token, certs) + assert payload["user"] == "billy bob" + + +class TestCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + SUBJECT = "subject" + AUDIENCE = "audience" + ADDITIONAL_CLAIMS = {"meta": "data"} + credentials = None + + @pytest.fixture(autouse=True) + def credentials_fixture(self, signer): + self.credentials = jwt.Credentials( + signer, + self.SERVICE_ACCOUNT_EMAIL, + self.SERVICE_ACCOUNT_EMAIL, + self.AUDIENCE, + ) + + def test_from_service_account_info(self): + with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + info = json.load(fh) + + credentials = jwt.Credentials.from_service_account_info( + info, audience=self.AUDIENCE + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + assert credentials._audience == self.AUDIENCE + + def test_from_service_account_info_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.Credentials.from_service_account_info( + info, + subject=self.SUBJECT, + audience=self.AUDIENCE, + additional_claims=self.ADDITIONAL_CLAIMS, + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._audience == self.AUDIENCE + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_service_account_file(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.Credentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + assert credentials._audience == self.AUDIENCE + + def test_from_service_account_file_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.Credentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE, + subject=self.SUBJECT, + audience=self.AUDIENCE, + additional_claims=self.ADDITIONAL_CLAIMS, + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._audience == self.AUDIENCE + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_signing_credentials(self): + jwt_from_signing = self.credentials.from_signing_credentials( + self.credentials, audience=mock.sentinel.new_audience + ) + jwt_from_info = jwt.Credentials.from_service_account_info( + SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience + ) + + assert isinstance(jwt_from_signing, jwt.Credentials) + assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id + assert jwt_from_signing._issuer == jwt_from_info._issuer + assert jwt_from_signing._subject == jwt_from_info._subject + assert jwt_from_signing._audience == jwt_from_info._audience + + def test_default_state(self): + assert not self.credentials.valid + # Expiration hasn't been set yet + assert not self.credentials.expired + + def test_with_claims(self): + new_audience = "new_audience" + new_credentials = self.credentials.with_claims(audience=new_audience) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._audience == new_audience + assert new_credentials._additional_claims == self.credentials._additional_claims + assert new_credentials._quota_project_id == self.credentials._quota_project_id + + def test_with_quota_project(self): + quota_project_id = "project-foo" + + new_credentials = self.credentials.with_quota_project(quota_project_id) + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._audience == self.credentials._audience + assert new_credentials._additional_claims == self.credentials._additional_claims + assert new_credentials._quota_project_id == quota_project_id + + def test_sign_bytes(self): + to_sign = b"123" + signature = self.credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + + def test_signer(self): + assert isinstance(self.credentials.signer, crypt.RSASigner) + + def test_signer_email(self): + assert self.credentials.signer_email == SERVICE_ACCOUNT_INFO["client_email"] + + def _verify_token(self, token): + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + return payload + + def test_refresh(self): + self.credentials.refresh(None) + assert self.credentials.valid + assert not self.credentials.expired + + def test_expired(self): + assert not self.credentials.expired + + self.credentials.refresh(None) + assert not self.credentials.expired + + with mock.patch("google.auth._helpers.utcnow") as now: + one_day = datetime.timedelta(days=1) + now.return_value = self.credentials.expiry + one_day + assert self.credentials.expired + + @pytest.mark.asyncio + async def test_before_request(self): + headers = {} + + self.credentials.refresh(None) + await self.credentials.before_request( + None, "GET", "http://example.com?a=1#3", headers + ) + + header_value = headers["authorization"] + _, token = header_value.split(" ") + + # Since the audience is set, it should use the existing token. + assert token.encode("utf-8") == self.credentials.token + + payload = self._verify_token(token) + assert payload["aud"] == self.AUDIENCE + + @pytest.mark.asyncio + async def test_before_request_refreshes(self): + assert not self.credentials.valid + await self.credentials.before_request( + None, "GET", "http://example.com?a=1#3", {} + ) + assert self.credentials.valid + + +class TestOnDemandCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + SUBJECT = "subject" + ADDITIONAL_CLAIMS = {"meta": "data"} + credentials = None + + @pytest.fixture(autouse=True) + def credentials_fixture(self, signer): + self.credentials = jwt.OnDemandCredentials( + signer, + self.SERVICE_ACCOUNT_EMAIL, + self.SERVICE_ACCOUNT_EMAIL, + max_cache_size=2, + ) + + def test_from_service_account_info(self): + with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + info = json.load(fh) + + credentials = jwt.OnDemandCredentials.from_service_account_info(info) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + + def test_from_service_account_info_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.OnDemandCredentials.from_service_account_info( + info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_service_account_file(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.OnDemandCredentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + + def test_from_service_account_file_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.OnDemandCredentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE, + subject=self.SUBJECT, + additional_claims=self.ADDITIONAL_CLAIMS, + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_signing_credentials(self): + jwt_from_signing = self.credentials.from_signing_credentials(self.credentials) + jwt_from_info = jwt.OnDemandCredentials.from_service_account_info( + SERVICE_ACCOUNT_INFO + ) + + assert isinstance(jwt_from_signing, jwt.OnDemandCredentials) + assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id + assert jwt_from_signing._issuer == jwt_from_info._issuer + assert jwt_from_signing._subject == jwt_from_info._subject + + def test_default_state(self): + # Credentials are *always* valid. + assert self.credentials.valid + # Credentials *never* expire. + assert not self.credentials.expired + + def test_with_claims(self): + new_claims = {"meep": "moop"} + new_credentials = self.credentials.with_claims(additional_claims=new_claims) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._additional_claims == new_claims + + def test_with_quota_project(self): + quota_project_id = "project-foo" + new_credentials = self.credentials.with_quota_project(quota_project_id) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._additional_claims == self.credentials._additional_claims + assert new_credentials._quota_project_id == quota_project_id + + def test_sign_bytes(self): + to_sign = b"123" + signature = self.credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + + def test_signer(self): + assert isinstance(self.credentials.signer, crypt.RSASigner) + + def test_signer_email(self): + assert self.credentials.signer_email == SERVICE_ACCOUNT_INFO["client_email"] + + def _verify_token(self, token): + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + return payload + + def test_refresh(self): + with pytest.raises(exceptions.RefreshError): + self.credentials.refresh(None) + + def test_before_request(self): + headers = {} + + self.credentials.before_request( + None, "GET", "http://example.com?a=1#3", headers + ) + + _, token = headers["authorization"].split(" ") + payload = self._verify_token(token) + + assert payload["aud"] == "http://example.com" + + # Making another request should re-use the same token. + self.credentials.before_request(None, "GET", "http://example.com?b=2", headers) + + _, new_token = headers["authorization"].split(" ") + + assert new_token == token + + def test_expired_token(self): + self.credentials._cache["audience"] = ( + mock.sentinel.token, + datetime.datetime.min, + ) + + token = self.credentials._get_jwt_for_audience("audience") + + assert token != mock.sentinel.token From 06606ff46cf6226528b0da034f7acf2ecb362464 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 12:39:53 -0500 Subject: [PATCH 03/20] feat: added the private scope for Response class --- google/auth/transport/aiohttp_req.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/google/auth/transport/aiohttp_req.py b/google/auth/transport/aiohttp_req.py index cf3f7abe1..3c577a5ce 100644 --- a/google/auth/transport/aiohttp_req.py +++ b/google/auth/transport/aiohttp_req.py @@ -51,19 +51,19 @@ class _Response(transport.Response): """ def __init__(self, response): - self.response = response + self._response = response @property def status(self): - return self.response.status + return self._response.status @property def headers(self): - return self.response.headers + return self._response.headers @property def data(self): - return self.response.content + return self._response.content class Request(transport.Request): @@ -214,6 +214,7 @@ async def request( **kwargs ): + if self._auth_request is None: self._auth_request_session = aiohttp.ClientSession() auth_request = Request(self._auth_request_session) @@ -292,6 +293,6 @@ async def request( **kwargs ) - await self._auth_request_session.close() + #await self._auth_request_session.close() return response From cd57fadc91e0fc938e7957c6372902e23bd540cf Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 12:42:08 -0500 Subject: [PATCH 04/20] feat: added docstring for Auth Session request method --- google/auth/transport/aiohttp_req.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/google/auth/transport/aiohttp_req.py b/google/auth/transport/aiohttp_req.py index 3c577a5ce..77fa4eb0f 100644 --- a/google/auth/transport/aiohttp_req.py +++ b/google/auth/transport/aiohttp_req.py @@ -214,6 +214,38 @@ async def request( **kwargs ): + """Implementation of Authorized Session aiohttp request. + + Args: + method: The http request method used (e.g. GET, PUT, DELETE) + + url: The url at which the http request is sent. + + data, headers: These fields parallel the associated data and headers + fields of a regular http request. Using the aiohttp client session to + send the http request allows us to use this parallel corresponding structure + in our Authorized Session class. + + timeout (Optional[Union[float, Tuple[float, float]]]): + The amount of time in seconds to wait for the server response + with each individual request. + + Can also be passed as a tuple (connect_timeout, read_timeout). + See :meth:`requests.Session.request` documentation for details. + + max_allowed_time (Optional[float]): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time`. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. + """ if self._auth_request is None: self._auth_request_session = aiohttp.ClientSession() @@ -293,6 +325,6 @@ async def request( **kwargs ) - #await self._auth_request_session.close() + await self._auth_request_session.close() return response From a38e33339da8e148b0a7ce124b8627c92afb337a Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 13:53:57 -0500 Subject: [PATCH 05/20] fix: Changed initialization of client session to within an async context manager --- google/auth/transport/aiohttp_req.py | 125 +++++++++++++-------------- 1 file changed, 62 insertions(+), 63 deletions(-) diff --git a/google/auth/transport/aiohttp_req.py b/google/auth/transport/aiohttp_req.py index 77fa4eb0f..d2f1ab6c8 100644 --- a/google/auth/transport/aiohttp_req.py +++ b/google/auth/transport/aiohttp_req.py @@ -247,56 +247,16 @@ async def request( request completes. """ - if self._auth_request is None: - self._auth_request_session = aiohttp.ClientSession() + async with aiohttp.ClientSession() as self._auth_request_session: auth_request = Request(self._auth_request_session) self._auth_request = auth_request - # Use a kwarg for this instead of an attribute to maintain - # thread-safety. - _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) - # Make a copy of the headers. They will be modified by the credentials - # and we want to pass the original headers if we recurse. - request_headers = headers.copy() if headers is not None else {} - - # Do not apply the timeout unconditionally in order to not override the - # _auth_request's default timeout. - auth_request = ( - self._auth_request - if timeout is None - else functools.partial(self._auth_request, timeout=timeout) - ) - - remaining_time = max_allowed_time - - with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: - await self.credentials.before_request( - auth_request, method, url, request_headers - ) - - with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: - response = await super(AuthorizedSession, self).request( - method, - url, - data=data, - headers=request_headers, - timeout=timeout, - **kwargs - ) - - remaining_time = guard.remaining_timeout - - if ( - response.status in self._refresh_status_codes - and _credential_refresh_attempt < self._max_refresh_attempts - ): - - _LOGGER.info( - "Refreshing credentials due to a %s response. Attempt %s/%s.", - response.status, - _credential_refresh_attempt + 1, - self._max_refresh_attempts, - ) + # Use a kwarg for this instead of an attribute to maintain + # thread-safety. + _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) + # Make a copy of the headers. They will be modified by the credentials + # and we want to pass the original headers if we recurse. + request_headers = headers.copy() if headers is not None else {} # Do not apply the timeout unconditionally in order to not override the # _auth_request's default timeout. @@ -306,25 +266,64 @@ async def request( else functools.partial(self._auth_request, timeout=timeout) ) + remaining_time = max_allowed_time + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: - async with self._refresh_lock: - await self._loop.run_in_executor( - None, self.credentials.refresh, auth_request - ) + await self.credentials.before_request( + auth_request, method, url, request_headers + ) - remaining_time = guard.remaining_timeout + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + response = await super(AuthorizedSession, self).request( + method, + url, + data=data, + headers=request_headers, + timeout=timeout, + **kwargs + ) - return await self.request( - method, - url, - data=data, - headers=headers, - max_allowed_time=remaining_time, - timeout=timeout, - _credential_refresh_attempt=_credential_refresh_attempt + 1, - **kwargs - ) + remaining_time = guard.remaining_timeout - await self._auth_request_session.close() + if ( + response.status in self._refresh_status_codes + and _credential_refresh_attempt < self._max_refresh_attempts + ): + + _LOGGER.info( + "Refreshing credentials due to a %s response. Attempt %s/%s.", + response.status, + _credential_refresh_attempt + 1, + self._max_refresh_attempts, + ) + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + with requests.TimeoutGuard( + remaining_time, asyncio.TimeoutError + ) as guard: + async with self._refresh_lock: + await self._loop.run_in_executor( + None, self.credentials.refresh, auth_request + ) + + remaining_time = guard.remaining_timeout + + return await self.request( + method, + url, + data=data, + headers=headers, + max_allowed_time=remaining_time, + timeout=timeout, + _credential_refresh_attempt=_credential_refresh_attempt + 1, + **kwargs + ) return response From 57d6d10b075f767bd7807497ecdaf215718b5cfc Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 14:44:17 -0500 Subject: [PATCH 06/20] changed aiohttp_requests abbreviation for the async authorized session class --- .../{aiohttp_req.py => aiohttp_requests.py} | 6 +--- tests_async/transport/test_aiohttp_req.py | 28 +++++++++---------- 2 files changed, 15 insertions(+), 19 deletions(-) rename google/auth/transport/{aiohttp_req.py => aiohttp_requests.py} (98%) diff --git a/google/auth/transport/aiohttp_req.py b/google/auth/transport/aiohttp_requests.py similarity index 98% rename from google/auth/transport/aiohttp_req.py rename to google/auth/transport/aiohttp_requests.py index d2f1ab6c8..c52780a8d 100644 --- a/google/auth/transport/aiohttp_req.py +++ b/google/auth/transport/aiohttp_requests.py @@ -91,11 +91,7 @@ class Request(transport.Request): """ def __init__(self, session=None): - """ - self.session = None - if not session: - session = aiohttp.ClientSession() - """ + self.session = None async def __call__( diff --git a/tests_async/transport/test_aiohttp_req.py b/tests_async/transport/test_aiohttp_req.py index 1643eaa06..4c3d9717b 100644 --- a/tests_async/transport/test_aiohttp_req.py +++ b/tests_async/transport/test_aiohttp_req.py @@ -19,21 +19,21 @@ from tests_async.transport import async_compliance import google.auth.credentials_async -from google.auth.transport import aiohttp_req +from google.auth.transport import aiohttp_requests import google.auth.transport._mtls_helper class TestRequestResponse(async_compliance.RequestResponseTests): def make_request(self): - return aiohttp_req.Request() + return aiohttp_requests.Request() def make_with_parameter_request(self): http = mock.create_autospec(aiohttp.ClientSession, instance=True) - return aiohttp_req.Request(http) + return aiohttp_requests.Request(http) def test_timeout(self): http = mock.create_autospec(aiohttp.ClientSession, instance=True) - request = google.auth.transport.aiohttp_req.Request(http) + request = google.auth.transport.aiohttp_requests.Request(http) request(url="http://example.com", method="GET", timeout=5) @@ -54,16 +54,16 @@ class TestAuthorizedSession(object): method = "GET" def test_constructor(self): - authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( mock.sentinel.credentials ) assert authed_session.credentials == mock.sentinel.credentials def test_constructor_with_auth_request(self): http = mock.create_autospec(aiohttp.ClientSession) - auth_request = google.auth.transport.aiohttp_req.Request(http) + auth_request = google.auth.transport.aiohttp_requests.Request(http) - authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( mock.sentinel.credentials, auth_request=auth_request ) @@ -75,7 +75,7 @@ async def test_request(self): credentials = mock.Mock(wraps=CredentialsStub()) mocked.get(self.TEST_URL, status=200, body="test") - session = aiohttp_req.AuthorizedSession(credentials) + session = aiohttp_requests.AuthorizedSession(credentials) resp = await session.request("GET", "http://example.com/") assert resp.status == 200 @@ -88,7 +88,7 @@ async def test_ctx(self): with aioresponses() as mocked: credentials = mock.Mock(wraps=CredentialsStub()) mocked.get("http://test.example.com", payload=dict(foo="bar")) - session = aiohttp_req.AuthorizedSession(credentials) + session = aiohttp_requests.AuthorizedSession(credentials) resp = await session.request("GET", "http://test.example.com") data = await resp.json() @@ -106,7 +106,7 @@ async def test_http_headers(self): headers=dict(connection="keep-alive"), ) - session = aiohttp_req.AuthorizedSession(credentials) + session = aiohttp_requests.AuthorizedSession(credentials) resp = await session.request("POST", "http://example.com") assert resp.headers["Connection"] == "keep-alive" @@ -120,10 +120,10 @@ async def test_regexp_example(self): mocked.get("http://example.com", status=500) mocked.get("http://example.com", status=200) - session1 = aiohttp_req.AuthorizedSession(credentials) + session1 = aiohttp_requests.AuthorizedSession(credentials) resp1 = await session1.request("GET", "http://example.com") - session2 = aiohttp_req.AuthorizedSession(credentials) + session2 = aiohttp_requests.AuthorizedSession(credentials) resp2 = await session2.request("GET", "http://example.com") assert resp1.status == 500 @@ -137,7 +137,7 @@ async def test_request_no_refresh(self): credentials = mock.Mock(wraps=CredentialsStub()) with aioresponses() as mocked: mocked.get("http://example.com", status=200) - authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( credentials ) response = await authed_session.request("GET", "http://example.com") @@ -153,7 +153,7 @@ async def test_request_refresh(self): with aioresponses() as mocked: mocked.get("http://example.com", status=401) mocked.get("http://example.com", status=200) - authed_session = google.auth.transport.aiohttp_req.AuthorizedSession( + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( credentials ) response = await authed_session.request("GET", "http://example.com") From cc14082ed9dd7779946e34a02155d549aaddbda0 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 14:55:11 -0500 Subject: [PATCH 07/20] fix: changed abbrevation of the aiohttp_requests file --- google/auth/transport/aiohttp_requests.py | 14 +++++++------- tests_async/conftest.py | 2 +- tests_async/test__default.py | 2 +- tests_async/test__oauth2client.py | 2 +- tests_async/test_credentials.py | 2 +- tests_async/test_jwt.py | 2 +- tests_async/transport/async_compliance.py | 2 +- ...est_aiohttp_req.py => test_aiohttp_requests.py} | 0 8 files changed, 13 insertions(+), 13 deletions(-) rename tests_async/transport/{test_aiohttp_req.py => test_aiohttp_requests.py} (100%) diff --git a/google/auth/transport/aiohttp_requests.py b/google/auth/transport/aiohttp_requests.py index c52780a8d..45f10e317 100644 --- a/google/auth/transport/aiohttp_requests.py +++ b/google/auth/transport/aiohttp_requests.py @@ -76,10 +76,10 @@ class Request(transport.Request): This class can be useful if you want to manually refresh a :class:`~google.auth.credentials.Credentials` instance:: - import google.auth.transport.aiohttp_req + import google.auth.transport.aiohttp_requests import aiohttp - request = google.auth.transport.aiohttp_req.Request() + request = google.auth.transport.aiohttp_requests.Request() credentials.refresh(request) @@ -153,9 +153,9 @@ class AuthorizedSession(aiohttp.ClientSession): This class is used to perform requests to API endpoints that require authorization:: - import google.auth.transport.aiohttp_req + import google.auth.transport.aiohttp_requests - async with aiohttp_req.AuthorizedSession(credentials) as authed_session: + async with aiohttp_requests.AuthorizedSession(credentials) as authed_session: response = await authed_session.request( 'GET', 'https://www.googleapis.com/storage/v1/b') @@ -172,11 +172,11 @@ class AuthorizedSession(aiohttp.ClientSession): refresh the credentials and retry the request. refresh_timeout (Optional[int]): The timeout value in seconds for credential refresh HTTP requests. - auth_request (google.auth.transport.aiohttp_req.Request): + auth_request (google.auth.transport.aiohttp_requests.Request): (Optional) An instance of - :class:`~google.auth.transport.aiohttp_req.Request` used when + :class:`~google.auth.transport.aiohttp_requests.Request` used when refreshing credentials. If not passed, - an instance of :class:`~google.auth.transport.aiohttp_req.Request` + an instance of :class:`~google.auth.transport.aiohttp_requests.Request` is created. """ diff --git a/tests_async/conftest.py b/tests_async/conftest.py index 6989a6375..b4e90f0e8 100644 --- a/tests_async/conftest.py +++ b/tests_async/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/test__default.py b/tests_async/test__default.py index d41edf76c..92300637f 100644 --- a/tests_async/test__default.py +++ b/tests_async/test__default.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/test__oauth2client.py b/tests_async/test__oauth2client.py index ce0b7defc..41d07cca8 100644 --- a/tests_async/test__oauth2client.py +++ b/tests_async/test__oauth2client.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/test_credentials.py b/tests_async/test_credentials.py index 377f9a7e2..7c65a52bc 100644 --- a/tests_async/test_credentials.py +++ b/tests_async/test_credentials.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/test_jwt.py b/tests_async/test_jwt.py index c3720b075..267e2dd26 100644 --- a/tests_async/test_jwt.py +++ b/tests_async/test_jwt.py @@ -1,4 +1,4 @@ -# Copyright 2014 Google Inc. +# Copyright 2020 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py index 0f204bd56..4bb027399 100644 --- a/tests_async/transport/async_compliance.py +++ b/tests_async/transport/async_compliance.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/transport/test_aiohttp_req.py b/tests_async/transport/test_aiohttp_requests.py similarity index 100% rename from tests_async/transport/test_aiohttp_req.py rename to tests_async/transport/test_aiohttp_requests.py From aa6ece216301b792f92bd71bbe22bd31a96c8f23 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 15:16:49 -0500 Subject: [PATCH 08/20] fix: comments on PR regarding shared data between requests and aiohttp_requests --- google/auth/transport/aiohttp_requests.py | 32 ++++++------------- ...iohttp_req.py => test_aiohttp_requests.py} | 0 2 files changed, 10 insertions(+), 22 deletions(-) rename tests_async/transport/{test_aiohttp_req.py => test_aiohttp_requests.py} (100%) diff --git a/google/auth/transport/aiohttp_requests.py b/google/auth/transport/aiohttp_requests.py index c52780a8d..46816ea5e 100644 --- a/google/auth/transport/aiohttp_requests.py +++ b/google/auth/transport/aiohttp_requests.py @@ -18,25 +18,14 @@ import asyncio import functools -import logging - import aiohttp import six - from google.auth import exceptions from google.auth import transport from google.auth.transport import requests - -_OAUTH_SCOPES = [ - "https://www.googleapis.com/auth/appengine.apis", - "https://www.googleapis.com/auth/userinfo.email", -] - -_LOGGER = logging.getLogger(__name__) - # Timeout can be re-defined depending on async requirement. Currently made 60s more than # sync timeout. _DEFAULT_TIMEOUT = 180 # in seconds @@ -76,10 +65,9 @@ class Request(transport.Request): This class can be useful if you want to manually refresh a :class:`~google.auth.credentials.Credentials` instance:: - import google.auth.transport.aiohttp_req - import aiohttp + import google.auth.transport.aiohttp_requests - request = google.auth.transport.aiohttp_req.Request() + request = google.auth.transport.aiohttp_requests.Request() credentials.refresh(request) @@ -107,7 +95,7 @@ async def __call__( Make an HTTP request using aiohttp. Args: - url (str): The URI to be requested. + url (str): The URL to be requested. method (str): The HTTP method to use for the request. Defaults to 'GET'. body (bytes): The payload / body in HTTP request. @@ -128,7 +116,7 @@ async def __call__( try: if self.session is None: # pragma: NO COVER self.session = aiohttp.ClientSession() # pragma: NO COVER - _LOGGER.debug("Making request: %s %s", method, url) + requests._LOGGER.debug("Making request: %s %s", method, url) response = await self.session.request( method, url, data=body, headers=headers, timeout=timeout, **kwargs ) @@ -153,9 +141,9 @@ class AuthorizedSession(aiohttp.ClientSession): This class is used to perform requests to API endpoints that require authorization:: - import google.auth.transport.aiohttp_req + import google.auth.transport.aiohttp_requests - async with aiohttp_req.AuthorizedSession(credentials) as authed_session: + async with aiohttp_requests.AuthorizedSession(credentials) as authed_session: response = await authed_session.request( 'GET', 'https://www.googleapis.com/storage/v1/b') @@ -172,11 +160,11 @@ class AuthorizedSession(aiohttp.ClientSession): refresh the credentials and retry the request. refresh_timeout (Optional[int]): The timeout value in seconds for credential refresh HTTP requests. - auth_request (google.auth.transport.aiohttp_req.Request): + auth_request (google.auth.transport.aiohttp_requests.Request): (Optional) An instance of - :class:`~google.auth.transport.aiohttp_req.Request` used when + :class:`~google.auth.transport.aiohttp_requests.Request` used when refreshing credentials. If not passed, - an instance of :class:`~google.auth.transport.aiohttp_req.Request` + an instance of :class:`~google.auth.transport.aiohttp_requests.Request` is created. """ @@ -286,7 +274,7 @@ async def request( and _credential_refresh_attempt < self._max_refresh_attempts ): - _LOGGER.info( + requests._LOGGER.info( "Refreshing credentials due to a %s response. Attempt %s/%s.", response.status, _credential_refresh_attempt + 1, diff --git a/tests_async/transport/test_aiohttp_req.py b/tests_async/transport/test_aiohttp_requests.py similarity index 100% rename from tests_async/transport/test_aiohttp_req.py rename to tests_async/transport/test_aiohttp_requests.py From d88839b449c1924b8cbc056cb7327dc503c2148d Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 15:24:04 -0500 Subject: [PATCH 09/20] fix: fixed noxfile test dependency sharing --- noxfile.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/noxfile.py b/noxfile.py index 0bd9544fa..f53ee5450 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,22 +32,6 @@ "aioresponses", ] -TEST_DEPENDENCIES2 = [ - "flask", - "freezegun", - "mock", - "oauth2client", - "pyopenssl", - "pytest", - "pytest-cov", - "pytest-localserver", - "requests", - "urllib3", - "cryptography", - "responses", - "grpcio", -] - BLACK_VERSION = "black==19.3b0" BLACK_PATHS = [ "google", @@ -107,7 +91,7 @@ def unit(session): @nox.session(python=["2.7", "3.5"]) def unit_prev_versions(session): - session.install(*TEST_DEPENDENCIES2) + session.install(*TEST_DEPENDENCIES[:-2]) session.install(".") session.run( "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests" From c62dd1a55b31ee6f4756aa557141b4a4ad25e9e8 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 16:08:01 -0500 Subject: [PATCH 10/20] fix: fixed the noxfile dependencies between sync and async unit tests --- noxfile.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/noxfile.py b/noxfile.py index f53ee5450..f7ab9da32 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,10 +28,10 @@ "cryptography", "responses", "grpcio", - "pytest-asyncio", - "aioresponses", ] +ASYNC_DEPENDENCIES = ["pytest-asyncio", "aioresponses"] + BLACK_VERSION = "black==19.3b0" BLACK_PATHS = [ "google", @@ -77,7 +77,7 @@ def blacken(session): @nox.session(python=["3.6", "3.7", "3.8"]) def unit(session): - session.install(*TEST_DEPENDENCIES) + session.install(*(TEST_DEPENDENCIES + ASYNC_DEPENDENCIES)) session.install(".") session.run( "pytest", @@ -91,7 +91,7 @@ def unit(session): @nox.session(python=["2.7", "3.5"]) def unit_prev_versions(session): - session.install(*TEST_DEPENDENCIES[:-2]) + session.install(*TEST_DEPENDENCIES) session.install(".") session.run( "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests" From 7080d147ebecb305811a380f45ab5d749b833c4d Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 16:27:19 -0500 Subject: [PATCH 11/20] fix: cover async dependency --- noxfile.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index f7ab9da32..7d42e60ac 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,7 +30,7 @@ "grpcio", ] -ASYNC_DEPENDENCIES = ["pytest-asyncio", "aioresponses"] +ASYNC_DEPENDENCIES = ["pytest-asyncio", "aioresponses", "aiohttp"] BLACK_VERSION = "black==19.3b0" BLACK_PATHS = [ @@ -77,7 +77,8 @@ def blacken(session): @nox.session(python=["3.6", "3.7", "3.8"]) def unit(session): - session.install(*(TEST_DEPENDENCIES + ASYNC_DEPENDENCIES)) + session.install(*TEST_DEPENDENCIES) + session.install(*(ASYNC_DEPENDENCIES)) session.install(".") session.run( "pytest", @@ -101,6 +102,7 @@ def unit_prev_versions(session): @nox.session(python="3.7") def cover(session): session.install(*TEST_DEPENDENCIES) + session.install(*(ASYNC_DEPENDENCIES)) session.install(".") session.run( "pytest", From b4c306f4584d10d08ffb3277503017efbf0bb35c Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 18:20:49 -0500 Subject: [PATCH 12/20] fix: merge conflict issue with credentials --- google/auth/credentials_async.py | 7 ------- tests_async/test_credentials.py | 4 ---- 2 files changed, 11 deletions(-) diff --git a/google/auth/credentials_async.py b/google/auth/credentials_async.py index f6be8ef1c..5916e45d9 100644 --- a/google/auth/credentials_async.py +++ b/google/auth/credentials_async.py @@ -16,10 +16,7 @@ """Interfaces for credentials.""" import abc -<<<<<<< HEAD import inspect -======= ->>>>>>> async import six @@ -66,14 +63,10 @@ async def before_request(self, request, method, url, headers): # the http request.) if not self.valid: -<<<<<<< HEAD if inspect.iscoroutinefunction(self.refresh): await self.refresh(request) else: self.refresh(request) -======= - self.refresh(request) ->>>>>>> async self.apply(headers) diff --git a/tests_async/test_credentials.py b/tests_async/test_credentials.py index 1966412a8..7c65a52bc 100644 --- a/tests_async/test_credentials.py +++ b/tests_async/test_credentials.py @@ -1,8 +1,4 @@ -<<<<<<< HEAD # Copyright 2020 Google LLC -======= -# Copyright 2016 Google LLC ->>>>>>> async # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 9ddb9115f0ffad8d651e00bd6ae48f112550feec Mon Sep 17 00:00:00 2001 From: AniBadde Date: Tue, 28 Jul 2020 18:24:35 -0500 Subject: [PATCH 13/20] fix: merge conflict #2 --- tests_async/transport/async_compliance.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py index faec5d2e3..4bb027399 100644 --- a/tests_async/transport/async_compliance.py +++ b/tests_async/transport/async_compliance.py @@ -1,8 +1,4 @@ -<<<<<<< HEAD # Copyright 2020 Google LLC -======= -# Copyright 2016 Google LLC ->>>>>>> async # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 8f2d0ef830d26ecbc3bb1403d39cf8b81a193d02 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Wed, 29 Jul 2020 11:06:13 -0500 Subject: [PATCH 14/20] fix: changed duplicated constants for sync-->async inheritance relationship --- google/auth/_default_async.py | 42 +---- google/auth/_oauth2client_async.py | 10 +- google/auth/jwt_async.py | 20 +- google/oauth2/_client_async.py | 15 +- google/oauth2/credentials_async.py | 4 - google/oauth2/service_account_async.py | 2 - noxfile.py | 2 +- tests_async/oauth2/test__client_async.py | 43 ++--- tests_async/oauth2/test_credentials_async.py | 74 ++++---- .../oauth2/test_service_account_async.py | 103 +++++----- tests_async/test__default.py | 52 ++---- tests_async/test__oauth2client.py | 14 +- tests_async/test_jwt.py | 176 ++++++++---------- tests_async/transport/async_compliance.py | 7 +- 14 files changed, 238 insertions(+), 326 deletions(-) diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py index b901aa0a7..2c367e70a 100644 --- a/google/auth/_default_async.py +++ b/google/auth/_default_async.py @@ -1,4 +1,4 @@ -# Copyright 2015 Google Inc. +# Copyright 2020 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,42 +19,16 @@ import io import json -import logging import os import warnings import six +from google.auth import _default as default from google.auth import environment_vars from google.auth import exceptions import google.auth.transport._http_client -_LOGGER = logging.getLogger(__name__) - -# Valid types accepted for file-based credentials. -_AUTHORIZED_USER_TYPE = "authorized_user" -_SERVICE_ACCOUNT_TYPE = "service_account" -_VALID_TYPES = (_AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE) - -# Help message when no credentials can be found. -_HELP_MESSAGE = """\ -Could not automatically determine credentials. Please set {env} or \ -explicitly create credentials and re-run the application. For more \ -information, please see \ -https://cloud.google.com/docs/authentication/getting-started -""".format( - env=environment_vars.CREDENTIALS -).strip() - -# Warning when using Cloud SDK user credentials -_CLOUD_SDK_CREDENTIALS_WARNING = """\ -Your application has authenticated using end user credentials from Google \ -Cloud SDK without a quota project. You might receive a "quota exceeded" \ -or "API not enabled" error. We recommend you rerun \ -`gcloud auth application-default login` and make sure a quota project is \ -added. Or you can use service accounts instead. For more information \ -about service accounts, see https://cloud.google.com/docs/authentication/""" - def _warn_about_problematic_credentials(credentials): """Determines if the credentials are problematic. @@ -66,7 +40,7 @@ def _warn_about_problematic_credentials(credentials): from google.auth import _cloud_sdk if credentials.client_id == _cloud_sdk.CLOUD_SDK_CLIENT_ID: - warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING) + warnings.warn(default._CLOUD_SDK_CREDENTIALS_WARNING) def load_credentials_from_file(filename, scopes=None, quota_project_id=None): @@ -110,7 +84,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): # credentials file or an authorized user credentials file. credential_type = info.get("type") - if credential_type == _AUTHORIZED_USER_TYPE: + if credential_type == default._AUTHORIZED_USER_TYPE: from google.oauth2 import credentials_async as credentials try: @@ -125,7 +99,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): _warn_about_problematic_credentials(credentials) return credentials, None - elif credential_type == _SERVICE_ACCOUNT_TYPE: + elif credential_type == default._SERVICE_ACCOUNT_TYPE: from google.oauth2 import service_account_async as service_account try: @@ -142,7 +116,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): raise exceptions.DefaultCredentialsError( "The file {file} does not have a valid type. " "Type is {type}, expected one of {valid_types}.".format( - file=filename, type=credential_type, valid_types=_VALID_TYPES + file=filename, type=credential_type, valid_types=default._VALID_TYPES ) ) @@ -322,7 +296,7 @@ def default_async(scopes=None, request=None, quota_project_id=None): ).with_quota_project(quota_project_id) effective_project_id = explicit_project_id or project_id if not effective_project_id: - _LOGGER.warning( + default._LOGGER.warning( "No project ID could be determined. Consider running " "`gcloud config set project` or setting the %s " "environment variable", @@ -330,4 +304,4 @@ def default_async(scopes=None, request=None, quota_project_id=None): ) return credentials, effective_project_id - raise exceptions.DefaultCredentialsError(_HELP_MESSAGE) + raise exceptions.DefaultCredentialsError(default._HELP_MESSAGE) diff --git a/google/auth/_oauth2client_async.py b/google/auth/_oauth2client_async.py index 2913134a4..870e55019 100644 --- a/google/auth/_oauth2client_async.py +++ b/google/auth/_oauth2client_async.py @@ -24,12 +24,15 @@ import six from google.auth import _helpers +from google.auth import _oauth2client import google.auth.app_engine import google.auth.compute_engine import google.oauth2.credentials import google.oauth2.service_account_async + try: + import oauth2client import oauth2client.client import oauth2client.contrib.gce import oauth2client.service_account @@ -44,9 +47,6 @@ _HAS_APPENGINE = False -_CONVERT_ERROR_TMPL = "Unable to convert {} to a google-auth credentials class." - - def _convert_oauth2_credentials(credentials): """Converts to :class:`google.oauth2.credentials_async.Credentials`. @@ -167,5 +167,7 @@ def convert(credentials): try: return _CLASS_CONVERSION_MAP[credentials_class](credentials) except KeyError as caught_exc: - new_exc = ValueError(_CONVERT_ERROR_TMPL.format(credentials_class)) + new_exc = ValueError( + _oauth2client._CONVERT_ERROR_TMPL.format(credentials_class) + ) six.raise_from(new_exc, caught_exc) diff --git a/google/auth/jwt_async.py b/google/auth/jwt_async.py index e9782bffb..a5293d73f 100644 --- a/google/auth/jwt_async.py +++ b/google/auth/jwt_async.py @@ -22,19 +22,19 @@ To encode a JWT use :func:`encode`:: from google.auth import crypt - from google.auth import jwt + from google.auth import jwt_async signer = crypt.Signer(private_key) payload = {'some': 'payload'} - encoded = jwt.encode(signer, payload) + encoded = jwt_async.encode(signer, payload) To decode a JWT and verify claims use :func:`decode`:: - claims = jwt.decode(encoded, certs=public_certs) + claims = jwt_async.decode(encoded, certs=public_certs) You can also skip verification:: - claims = jwt.decode(encoded, verify=False) + claims = jwt_async.decode(encoded, verify=False) .. _rfc7519: https://tools.ietf.org/html/rfc7519 @@ -60,14 +60,6 @@ except ImportError: # pragma: NO COVER es256 = None -_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds -_DEFAULT_MAX_CACHE_SIZE = 10 -_ALGORITHM_TO_VERIFIER_CLASS = {"RS256": crypt.RSAVerifier} -_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256"]) - -if es256 is not None: # pragma: NO COVER - _ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es256.ES256Verifier - def encode(signer, payload, header=None, key_id=None): """Make a signed JWT. @@ -234,9 +226,9 @@ def decode(token, certs=None, verify=True, audience=None): key_id = header.get("kid") try: - verifier_cls = _ALGORITHM_TO_VERIFIER_CLASS[key_alg] + verifier_cls = jwt._ALGORITHM_TO_VERIFIER_CLASS[key_alg] except KeyError as exc: - if key_alg in _CRYPTOGRAPHY_BASED_ALGORITHMS: + if key_alg in jwt._CRYPTOGRAPHY_BASED_ALGORITHMS: six.raise_from( ValueError( "The key algorithm {} requires the cryptography package " diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index e498b8113..99bae7f55 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,10 +33,7 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import jwt - -_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" -_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" -_REFRESH_GRANT_TYPE = "refresh_token" +from google.oauth2 import _client as client def _handle_error_response(response_body): @@ -96,7 +93,7 @@ async def _token_endpoint_request(request, token_uri, body): an error. """ body = urllib.parse.urlencode(body).encode("utf-8") - headers = {"content-type": _URLENCODED_CONTENT_TYPE} + headers = {"content-type": client._URLENCODED_CONTENT_TYPE} retry = 0 # retry to fetch token for maximum of two times if any internal failure @@ -162,7 +159,7 @@ async def jwt_grant(request, token_uri, assertion): .. _rfc7523 section 4: https://tools.ietf.org/html/rfc7523#section-4 """ - body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} + body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} response_data = await _token_endpoint_request(request, token_uri, body) @@ -202,7 +199,7 @@ async def id_token_jwt_grant(request, token_uri, assertion): google.auth.exceptions.RefreshError: If the token endpoint returned an error. """ - body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} + body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} response_data = await _token_endpoint_request(request, token_uri, body) @@ -251,7 +248,7 @@ async def refresh_grant( .. _rfc6748 section 6: https://tools.ietf.org/html/rfc6749#section-6 """ body = { - "grant_type": _REFRESH_GRANT_TYPE, + "grant_type": client._REFRESH_GRANT_TYPE, "client_id": client_id, "client_secret": client_secret, "refresh_token": refresh_token, diff --git a/google/oauth2/credentials_async.py b/google/oauth2/credentials_async.py index b45feddc6..092aa5781 100644 --- a/google/oauth2/credentials_async.py +++ b/google/oauth2/credentials_async.py @@ -38,10 +38,6 @@ from google.oauth2 import credentials as oauth2_credentials -# The Google OAuth 2.0 token endpoint. Used for authorized user credentials. -_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" - - class Credentials(oauth2_credentials.Credentials): """Credentials using OAuth 2.0 access and refresh tokens. diff --git a/google/oauth2/service_account_async.py b/google/oauth2/service_account_async.py index 333e06de2..a81a48be4 100644 --- a/google/oauth2/service_account_async.py +++ b/google/oauth2/service_account_async.py @@ -27,8 +27,6 @@ from google.oauth2 import _client_async from google.oauth2 import service_account -_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds - class Credentials( service_account.Credentials, credentials_async.Scoped, credentials_async.Credentials diff --git a/noxfile.py b/noxfile.py index 7d42e60ac..17213a9d0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,4 +1,4 @@ -# Copyright 2019 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py index fd5c17dd5..c32a183a6 100644 --- a/tests_async/oauth2/test__client_async.py +++ b/tests_async/oauth2/test__client_async.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ import datetime import json -import os import mock import pytest @@ -23,29 +22,11 @@ from six.moves import urllib from google.auth import _helpers -from google.auth import crypt from google.auth import exceptions from google.auth import jwt_async as jwt +from google.oauth2 import _client as sync_client from google.oauth2 import _client_async as _client - - -DATA_DIR = os.path.join( - os.path.abspath(os.path.join(__file__, "../../..")), "tests/data" -) - -with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: - PRIVATE_KEY_BYTES = fh.read() - -SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") - -SCOPES_AS_LIST = [ - "https://www.googleapis.com/auth/pubsub", - "https://www.googleapis.com/auth/logging.write", -] -SCOPES_AS_STRING = ( - "https://www.googleapis.com/auth/pubsub" - " https://www.googleapis.com/auth/logging.write" -) +from tests.oauth2 import test__client as test_client def test__handle_error_response(): @@ -158,7 +139,8 @@ async def test_jwt_grant(utcnow): # Check request call verify_request_params( - request, {"grant_type": _client._JWT_GRANT_TYPE, "assertion": "assertion_value"} + request, + {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"}, ) # Check result @@ -185,7 +167,7 @@ async def test_jwt_grant_no_access_token(): async def test_id_token_jwt_grant(): now = _helpers.utcnow() id_token_expiry = _helpers.datetime_to_secs(now) - id_token = jwt.encode(SIGNER, {"exp": id_token_expiry}).decode("utf-8") + id_token = jwt.encode(test_client.SIGNER, {"exp": id_token_expiry}).decode("utf-8") request = make_request({"id_token": id_token, "extra": "data"}) token, expiry, extra_data = await _client.id_token_jwt_grant( @@ -194,7 +176,8 @@ async def test_id_token_jwt_grant(): # Check request call verify_request_params( - request, {"grant_type": _client._JWT_GRANT_TYPE, "assertion": "assertion_value"} + request, + {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"}, ) # Check result @@ -241,7 +224,7 @@ async def test_refresh_grant(unused_utcnow): verify_request_params( request, { - "grant_type": _client._REFRESH_GRANT_TYPE, + "grant_type": sync_client._REFRESH_GRANT_TYPE, "refresh_token": "refresh_token", "client_id": "client_id", "client_secret": "client_secret", @@ -264,7 +247,7 @@ async def test_refresh_grant_with_scopes(unused_utcnow): "refresh_token": "new_refresh_token", "expires_in": 500, "extra": "data", - "scope": SCOPES_AS_STRING, + "scope": test_client.SCOPES_AS_STRING, } ) @@ -274,18 +257,18 @@ async def test_refresh_grant_with_scopes(unused_utcnow): "refresh_token", "client_id", "client_secret", - SCOPES_AS_LIST, + test_client.SCOPES_AS_LIST, ) # Check request call. verify_request_params( request, { - "grant_type": _client._REFRESH_GRANT_TYPE, + "grant_type": sync_client._REFRESH_GRANT_TYPE, "refresh_token": "refresh_token", "client_id": "client_id", "client_secret": "client_secret", - "scope": SCOPES_AS_STRING, + "scope": test_client.SCOPES_AS_STRING, }, ) diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py index a3614b3bd..a5ffbaffc 100644 --- a/tests_async/oauth2/test_credentials_async.py +++ b/tests_async/oauth2/test_credentials_async.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,16 +23,9 @@ from google.auth import _helpers from google.auth import exceptions -from google.oauth2 import credentials_async as credentials - -DATA_DIR = os.path.join( - os.path.abspath(os.path.join(__file__, "../../..")), "tests/data" -) - -AUTH_USER_JSON_FILE = os.path.join(DATA_DIR, "authorized_user.json") - -with open(AUTH_USER_JSON_FILE, "r") as fh: - AUTH_USER_INFO = json.load(fh) +from google.oauth2 import credentials +from google.oauth2 import credentials_async +from tests.oauth2 import test_credentials class TestCredentials: @@ -44,7 +37,7 @@ class TestCredentials: @classmethod def make_credentials(cls): - return credentials.Credentials( + return credentials_async.Credentials( token=None, refresh_token=cls.REFRESH_TOKEN, token_uri=cls.TOKEN_URI, @@ -87,10 +80,10 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): ) request = mock.AsyncMock(spec=["transport.Request"]) - credentials = self.make_credentials() + creds = self.make_credentials() # Refresh credentials - await credentials.refresh(request) + await creds.refresh(request) # Check jwt grant call. refresh_grant.assert_called_with( @@ -103,18 +96,18 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): ) # Check that the credentials have the token and expiry - assert credentials.token == token - assert credentials.expiry == expiry - assert credentials.id_token == mock.sentinel.id_token + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token # Check that the credentials are valid (have a token and are not # expired) - assert credentials.valid + assert creds.valid @pytest.mark.asyncio async def test_refresh_no_refresh_token(self): request = mock.AsyncMock(spec=["transport.Request"]) - credentials_ = credentials.Credentials(token=None, refresh_token=None) + credentials_ = credentials_async.Credentials(token=None, refresh_token=None) with pytest.raises(exceptions.RefreshError, match="necessary fields"): await credentials_.refresh(request) @@ -146,7 +139,7 @@ async def test_credentials_with_scopes_requested_refresh_success( ) request = mock.AsyncMock(spec=["transport.Request"]) - creds = credentials.Credentials( + creds = credentials_async.Credentials( token=None, refresh_token=self.REFRESH_TOKEN, token_uri=self.TOKEN_URI, @@ -206,7 +199,7 @@ async def test_credentials_with_scopes_returned_refresh_success( ) request = mock.AsyncMock(spec=["transport.Request"]) - creds = credentials.Credentials( + creds = credentials_async.Credentials( token=None, refresh_token=self.REFRESH_TOKEN, token_uri=self.TOKEN_URI, @@ -267,7 +260,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( ) request = mock.AsyncMock(spec=["transport.Request"]) - creds = credentials.Credentials( + creds = credentials_async.Credentials( token=None, refresh_token=self.REFRESH_TOKEN, token_uri=self.TOKEN_URI, @@ -303,7 +296,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( assert creds.valid def test_apply_with_quota_project_id(self): - creds = credentials.Credentials( + creds = credentials_async.Credentials( token="token", refresh_token=self.REFRESH_TOKEN, token_uri=self.TOKEN_URI, @@ -317,7 +310,7 @@ def test_apply_with_quota_project_id(self): assert headers["x-goog-user-project"] == "quota-project-123" def test_apply_with_no_quota_project_id(self): - creds = credentials.Credentials( + creds = credentials_async.Credentials( token="token", refresh_token=self.REFRESH_TOKEN, token_uri=self.TOKEN_URI, @@ -330,7 +323,7 @@ def test_apply_with_no_quota_project_id(self): assert "x-goog-user-project" not in headers def test_with_quota_project(self): - creds = credentials.Credentials( + creds = credentials_async.Credentials( token="token", refresh_token=self.REFRESH_TOKEN, token_uri=self.TOKEN_URI, @@ -346,9 +339,9 @@ def test_with_quota_project(self): assert "x-goog-user-project" in headers def test_from_authorized_user_info(self): - info = AUTH_USER_INFO.copy() + info = test_credentials.AUTH_USER_INFO.copy() - creds = credentials.Credentials.from_authorized_user_info(info) + creds = credentials_async.Credentials.from_authorized_user_info(info) assert creds.client_secret == info["client_secret"] assert creds.client_id == info["client_id"] assert creds.refresh_token == info["refresh_token"] @@ -356,7 +349,7 @@ def test_from_authorized_user_info(self): assert creds.scopes is None scopes = ["email", "profile"] - creds = credentials.Credentials.from_authorized_user_info(info, scopes) + creds = credentials_async.Credentials.from_authorized_user_info(info, scopes) assert creds.client_secret == info["client_secret"] assert creds.client_id == info["client_id"] assert creds.refresh_token == info["refresh_token"] @@ -364,9 +357,11 @@ def test_from_authorized_user_info(self): assert creds.scopes == scopes def test_from_authorized_user_file(self): - info = AUTH_USER_INFO.copy() + info = test_credentials.AUTH_USER_INFO.copy() - creds = credentials.Credentials.from_authorized_user_file(AUTH_USER_JSON_FILE) + creds = credentials_async.Credentials.from_authorized_user_file( + test_credentials.AUTH_USER_JSON_FILE + ) assert creds.client_secret == info["client_secret"] assert creds.client_id == info["client_id"] assert creds.refresh_token == info["refresh_token"] @@ -374,8 +369,8 @@ def test_from_authorized_user_file(self): assert creds.scopes is None scopes = ["email", "profile"] - creds = credentials.Credentials.from_authorized_user_file( - AUTH_USER_JSON_FILE, scopes + creds = credentials_async.Credentials.from_authorized_user_file( + test_credentials.AUTH_USER_JSON_FILE, scopes ) assert creds.client_secret == info["client_secret"] assert creds.client_id == info["client_id"] @@ -384,8 +379,8 @@ def test_from_authorized_user_file(self): assert creds.scopes == scopes def test_to_json(self): - info = AUTH_USER_INFO.copy() - creds = credentials.Credentials.from_authorized_user_info(info) + info = test_credentials.AUTH_USER_INFO.copy() + creds = credentials_async.Credentials.from_authorized_user_info(info) # Test with no `strip` arg json_output = creds.to_json() @@ -439,7 +434,8 @@ def test_unpickle_old_credentials_pickle(self): # make sure a credentials file pickled with an older # library version (google-auth==1.5.1) can be unpickled with open( - os.path.join(DATA_DIR, "old_oauth_credentials_py3.pickle"), "rb" + os.path.join(test_credentials.DATA_DIR, "old_oauth_credentials_py3.pickle"), + "rb", ) as f: credentials = pickle.load(f) assert credentials.quota_project_id is None @@ -447,7 +443,7 @@ def test_unpickle_old_credentials_pickle(self): class TestUserAccessTokenCredentials(object): def test_instance(self): - cred = credentials.UserAccessTokenCredentials() + cred = credentials_async.UserAccessTokenCredentials() assert cred._account is None cred = cred.with_account("account") @@ -456,12 +452,12 @@ def test_instance(self): @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True) def test_refresh(self, get_auth_access_token): get_auth_access_token.return_value = "access_token" - cred = credentials.UserAccessTokenCredentials() + cred = credentials_async.UserAccessTokenCredentials() cred.refresh(None) assert cred.token == "access_token" def test_with_quota_project(self): - cred = credentials.UserAccessTokenCredentials() + cred = credentials_async.UserAccessTokenCredentials() quota_project_cred = cred.with_quota_project("project-foo") assert quota_project_cred._quota_project_id == "project-foo" @@ -476,7 +472,7 @@ def test_with_quota_project(self): autospec=True, ) def test_before_request(self, refresh, apply): - cred = credentials.UserAccessTokenCredentials() + cred = credentials_async.UserAccessTokenCredentials() cred.before_request(mock.Mock(), "GET", "https://example.com", {}) refresh.assert_called() apply.assert_called() diff --git a/tests_async/oauth2/test_service_account_async.py b/tests_async/oauth2/test_service_account_async.py index 22a1876da..65c86b442 100644 --- a/tests_async/oauth2/test_service_account_async.py +++ b/tests_async/oauth2/test_service_account_async.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,6 @@ # limitations under the License. import datetime -import json -import os import mock import pytest @@ -24,27 +22,7 @@ from google.auth import jwt from google.auth import transport from google.oauth2 import service_account_async as service_account - - -DATA_DIR = os.path.join( - os.path.abspath(os.path.join(__file__, "../../..")), "tests/data" -) - -with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: - PRIVATE_KEY_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: - PUBLIC_CERT_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh: - OTHER_CERT_BYTES = fh.read() - -SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") - -with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: - SERVICE_ACCOUNT_INFO = json.load(fh) - -SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") +from tests.oauth2 import test_service_account class TestCredentials(object): @@ -54,20 +32,29 @@ class TestCredentials(object): @classmethod def make_credentials(cls): return service_account.Credentials( - SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI + test_service_account.SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI ) def test_from_service_account_info(self): credentials = service_account.Credentials.from_service_account_info( - SERVICE_ACCOUNT_INFO + test_service_account.SERVICE_ACCOUNT_INFO ) - assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"] - assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"] - assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"] + assert ( + credentials._signer.key_id + == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"] + ) + assert ( + credentials.service_account_email + == test_service_account.SERVICE_ACCOUNT_INFO["client_email"] + ) + assert ( + credentials._token_uri + == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"] + ) def test_from_service_account_info_args(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() scopes = ["email", "profile"] subject = "subject" additional_claims = {"meta": "data"} @@ -85,10 +72,10 @@ def test_from_service_account_info_args(self): assert credentials._additional_claims == additional_claims def test_from_service_account_file(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() credentials = service_account.Credentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE + test_service_account.SERVICE_ACCOUNT_JSON_FILE ) assert credentials.service_account_email == info["client_email"] @@ -97,13 +84,13 @@ def test_from_service_account_file(self): assert credentials._token_uri == info["token_uri"] def test_from_service_account_file_args(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() scopes = ["email", "profile"] subject = "subject" additional_claims = {"meta": "data"} credentials = service_account.Credentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE, + test_service_account.SERVICE_ACCOUNT_JSON_FILE, subject=subject, scopes=scopes, additional_claims=additional_claims, @@ -129,7 +116,9 @@ def test_sign_bytes(self): credentials = self.make_credentials() to_sign = b"123" signature = credentials.sign_bytes(to_sign) - assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + assert crypt.verify_signature( + to_sign, signature, test_service_account.PUBLIC_CERT_BYTES + ) def test_signer(self): credentials = self.make_credentials() @@ -161,7 +150,7 @@ def test_with_quota_project(self): def test__make_authorization_grant_assertion(self): credentials = self.make_credentials() token = credentials._make_authorization_grant_assertion() - payload = jwt.decode(token, PUBLIC_CERT_BYTES) + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL assert payload["aud"] == self.TOKEN_URI @@ -170,7 +159,7 @@ def test__make_authorization_grant_assertion_scoped(self): scopes = ["email", "profile"] credentials = credentials.with_scopes(scopes) token = credentials._make_authorization_grant_assertion() - payload = jwt.decode(token, PUBLIC_CERT_BYTES) + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) assert payload["scope"] == "email profile" def test__make_authorization_grant_assertion_subject(self): @@ -178,7 +167,7 @@ def test__make_authorization_grant_assertion_subject(self): subject = "user@example.com" credentials = credentials.with_subject(subject) token = credentials._make_authorization_grant_assertion() - payload = jwt.decode(token, PUBLIC_CERT_BYTES) + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) assert payload["sub"] == subject @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True) @@ -202,7 +191,7 @@ async def test_refresh_success(self, jwt_grant): called_request, token_uri, assertion = jwt_grant.call_args[0] assert called_request == request assert token_uri == credentials._token_uri - assert jwt.decode(assertion, PUBLIC_CERT_BYTES) + assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES) # No further assertion done on the token, as there are separate tests # for checking the authorization grant assertion. @@ -246,24 +235,38 @@ class TestIDTokenCredentials(object): @classmethod def make_credentials(cls): return service_account.IDTokenCredentials( - SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI, cls.TARGET_AUDIENCE + test_service_account.SIGNER, + cls.SERVICE_ACCOUNT_EMAIL, + cls.TOKEN_URI, + cls.TARGET_AUDIENCE, ) def test_from_service_account_info(self): credentials = service_account.IDTokenCredentials.from_service_account_info( - SERVICE_ACCOUNT_INFO, target_audience=self.TARGET_AUDIENCE + test_service_account.SERVICE_ACCOUNT_INFO, + target_audience=self.TARGET_AUDIENCE, ) - assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"] - assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"] - assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"] + assert ( + credentials._signer.key_id + == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"] + ) + assert ( + credentials.service_account_email + == test_service_account.SERVICE_ACCOUNT_INFO["client_email"] + ) + assert ( + credentials._token_uri + == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"] + ) assert credentials._target_audience == self.TARGET_AUDIENCE def test_from_service_account_file(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() credentials = service_account.IDTokenCredentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE, target_audience=self.TARGET_AUDIENCE + test_service_account.SERVICE_ACCOUNT_JSON_FILE, + target_audience=self.TARGET_AUDIENCE, ) assert credentials.service_account_email == info["client_email"] @@ -281,7 +284,9 @@ def test_sign_bytes(self): credentials = self.make_credentials() to_sign = b"123" signature = credentials.sign_bytes(to_sign) - assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + assert crypt.verify_signature( + to_sign, signature, test_service_account.PUBLIC_CERT_BYTES + ) def test_signer(self): credentials = self.make_credentials() @@ -304,7 +309,7 @@ def test_with_quota_project(self): def test__make_authorization_grant_assertion(self): credentials = self.make_credentials() token = credentials._make_authorization_grant_assertion() - payload = jwt.decode(token, PUBLIC_CERT_BYTES) + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL assert payload["aud"] == self.TOKEN_URI assert payload["target_audience"] == self.TARGET_AUDIENCE @@ -331,7 +336,7 @@ async def test_refresh_success(self, id_token_jwt_grant): called_request, token_uri, assertion = id_token_jwt_grant.call_args[0] assert called_request == request assert token_uri == credentials._token_uri - assert jwt.decode(assertion, PUBLIC_CERT_BYTES) + assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES) # No further assertion done on the token, as there are separate tests # for checking the authorization grant assertion. diff --git a/tests_async/test__default.py b/tests_async/test__default.py index 92300637f..3fbd64b34 100644 --- a/tests_async/test__default.py +++ b/tests_async/test__default.py @@ -26,28 +26,7 @@ from google.auth import exceptions from google.oauth2 import service_account_async as service_account import google.oauth2.credentials - - -DATA_DIR = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "tests/data") -AUTHORIZED_USER_FILE = os.path.join(DATA_DIR, "authorized_user.json") - -with open(AUTHORIZED_USER_FILE) as fh: - AUTHORIZED_USER_FILE_DATA = json.load(fh) - -AUTHORIZED_USER_CLOUD_SDK_FILE = os.path.join( - DATA_DIR, "authorized_user_cloud_sdk.json" -) - -AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE = os.path.join( - DATA_DIR, "authorized_user_cloud_sdk_with_quota_project_id.json" -) - -SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "service_account.json") - -CLIENT_SECRETS_FILE = os.path.join(DATA_DIR, "client_secrets.json") - -with open(SERVICE_ACCOUNT_FILE) as fh: - SERVICE_ACCOUNT_FILE_DATA = json.load(fh) +from tests import test__default as test_default MOCK_CREDENTIALS = mock.Mock(spec=credentials.Credentials) MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS @@ -87,7 +66,9 @@ def test_load_credentials_from_file_invalid_type(tmpdir): def test_load_credentials_from_file_authorized_user(): - credentials, project_id = _default.load_credentials_from_file(AUTHORIZED_USER_FILE) + credentials, project_id = _default.load_credentials_from_file( + test_default.AUTHORIZED_USER_FILE + ) assert isinstance(credentials, google.oauth2.credentials_async.Credentials) assert project_id is None @@ -96,7 +77,7 @@ def test_load_credentials_from_file_no_type(tmpdir): # use the client_secrets.json, which is valid json but not a # loadable credentials type with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: - _default.load_credentials_from_file(CLIENT_SECRETS_FILE) + _default.load_credentials_from_file(test_default.CLIENT_SECRETS_FILE) assert excinfo.match(r"does not have a valid type") assert excinfo.match(r"Type is None") @@ -116,14 +97,14 @@ def test_load_credentials_from_file_authorized_user_bad_format(tmpdir): def test_load_credentials_from_file_authorized_user_cloud_sdk(): with pytest.warns(UserWarning, match="Cloud SDK"): credentials, project_id = _default.load_credentials_from_file( - AUTHORIZED_USER_CLOUD_SDK_FILE + test_default.AUTHORIZED_USER_CLOUD_SDK_FILE ) assert isinstance(credentials, google.oauth2.credentials_async.Credentials) assert project_id is None # No warning if the json file has quota project id. credentials, project_id = _default.load_credentials_from_file( - AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE + test_default.AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE ) assert isinstance(credentials, google.oauth2.credentials_async.Credentials) assert project_id is None @@ -132,7 +113,7 @@ def test_load_credentials_from_file_authorized_user_cloud_sdk(): def test_load_credentials_from_file_authorized_user_cloud_sdk_with_scopes(): with pytest.warns(UserWarning, match="Cloud SDK"): credentials, project_id = _default.load_credentials_from_file( - AUTHORIZED_USER_CLOUD_SDK_FILE, + test_default.AUTHORIZED_USER_CLOUD_SDK_FILE, scopes=["https://www.google.com/calendar/feeds"], ) assert isinstance(credentials, google.oauth2.credentials_async.Credentials) @@ -142,7 +123,7 @@ def test_load_credentials_from_file_authorized_user_cloud_sdk_with_scopes(): def test_load_credentials_from_file_authorized_user_cloud_sdk_with_quota_project(): credentials, project_id = _default.load_credentials_from_file( - AUTHORIZED_USER_CLOUD_SDK_FILE, quota_project_id="project-foo" + test_default.AUTHORIZED_USER_CLOUD_SDK_FILE, quota_project_id="project-foo" ) assert isinstance(credentials, google.oauth2.credentials_async.Credentials) @@ -151,17 +132,20 @@ def test_load_credentials_from_file_authorized_user_cloud_sdk_with_quota_project def test_load_credentials_from_file_service_account(): - credentials, project_id = _default.load_credentials_from_file(SERVICE_ACCOUNT_FILE) + credentials, project_id = _default.load_credentials_from_file( + test_default.SERVICE_ACCOUNT_FILE + ) assert isinstance(credentials, service_account.Credentials) - assert project_id == SERVICE_ACCOUNT_FILE_DATA["project_id"] + assert project_id == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"] def test_load_credentials_from_file_service_account_with_scopes(): credentials, project_id = _default.load_credentials_from_file( - SERVICE_ACCOUNT_FILE, scopes=["https://www.google.com/calendar/feeds"] + test_default.SERVICE_ACCOUNT_FILE, + scopes=["https://www.google.com/calendar/feeds"], ) assert isinstance(credentials, service_account.Credentials) - assert project_id == SERVICE_ACCOUNT_FILE_DATA["project_id"] + assert project_id == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"] assert credentials.scopes == ["https://www.google.com/calendar/feeds"] @@ -208,13 +192,13 @@ def test__get_explicit_environ_credentials_no_project_id(load, monkeypatch): "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True ) def test__get_gcloud_sdk_credentials(get_adc_path, load): - get_adc_path.return_value = SERVICE_ACCOUNT_FILE + get_adc_path.return_value = test_default.SERVICE_ACCOUNT_FILE credentials, project_id = _default._get_gcloud_sdk_credentials() assert credentials is MOCK_CREDENTIALS assert project_id is mock.sentinel.project_id - load.assert_called_with(SERVICE_ACCOUNT_FILE) + load.assert_called_with(test_default.SERVICE_ACCOUNT_FILE) @mock.patch( diff --git a/tests_async/test__oauth2client.py b/tests_async/test__oauth2client.py index 41d07cca8..af38dedda 100644 --- a/tests_async/test__oauth2client.py +++ b/tests_async/test__oauth2client.py @@ -13,7 +13,6 @@ # limitations under the License. import datetime -import os import sys import mock @@ -24,10 +23,7 @@ from six.moves import reload_module from google.auth import _oauth2client_async as _oauth2client - - -DATA_DIR = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "tests/data") -SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") +from tests import test__oauth2client as test_oauth2client def test__convert_oauth2_credentials(): @@ -54,7 +50,9 @@ def test__convert_oauth2_credentials(): def test__convert_service_account_credentials(): old_class = oauth2client.service_account.ServiceAccountCredentials - old_credentials = old_class.from_json_keyfile_name(SERVICE_ACCOUNT_JSON_FILE) + old_credentials = old_class.from_json_keyfile_name( + test_oauth2client.SERVICE_ACCOUNT_JSON_FILE + ) new_credentials = _oauth2client._convert_service_account_credentials( old_credentials @@ -69,7 +67,9 @@ def test__convert_service_account_credentials(): def test__convert_service_account_credentials_with_jwt(): old_class = oauth2client.service_account._JWTAccessCredentials - old_credentials = old_class.from_json_keyfile_name(SERVICE_ACCOUNT_JSON_FILE) + old_credentials = old_class.from_json_keyfile_name( + test_oauth2client.SERVICE_ACCOUNT_JSON_FILE + ) new_credentials = _oauth2client._convert_service_account_credentials( old_credentials diff --git a/tests_async/test_jwt.py b/tests_async/test_jwt.py index 267e2dd26..aa9f39f23 100644 --- a/tests_async/test_jwt.py +++ b/tests_async/test_jwt.py @@ -15,7 +15,6 @@ import base64 import datetime import json -import os import mock import pytest @@ -23,48 +22,27 @@ from google.auth import _helpers from google.auth import crypt from google.auth import exceptions -from google.auth import jwt_async as jwt - - -DATA_DIR = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "tests/data") - -with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: - PRIVATE_KEY_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: - PUBLIC_CERT_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh: - OTHER_CERT_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "es256_privatekey.pem"), "rb") as fh: - EC_PRIVATE_KEY_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "es256_public_cert.pem"), "rb") as fh: - EC_PUBLIC_CERT_BYTES = fh.read() - -SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") - -with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: - SERVICE_ACCOUNT_INFO = json.load(fh) +from google.auth import jwt +from google.auth import jwt_async +from tests import test_jwt @pytest.fixture def signer(): - return crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") + return crypt.RSASigner.from_string(test_jwt.PRIVATE_KEY_BYTES, "1") def test_encode_basic(signer): test_payload = {"test": "value"} - encoded = jwt.encode(signer, test_payload) - header, payload, _, _ = jwt._unverified_decode(encoded) + encoded = jwt_async.encode(signer, test_payload) + header, payload, _, _ = jwt_async._unverified_decode(encoded) assert payload == test_payload assert header == {"typ": "JWT", "alg": "RS256", "kid": signer.key_id} def test_encode_extra_headers(signer): - encoded = jwt.encode(signer, {}, header={"extra": "value"}) - header = jwt.decode_header(encoded) + encoded = jwt_async.encode(signer, {}, header={"extra": "value"}) + header = jwt_async.decode_header(encoded) assert header == { "typ": "JWT", "alg": "RS256", @@ -75,13 +53,13 @@ def test_encode_extra_headers(signer): @pytest.fixture def es256_signer(): - return crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1") + return crypt.ES256Signer.from_string(test_jwt.EC_PRIVATE_KEY_BYTES, "1") def test_encode_basic_es256(es256_signer): test_payload = {"test": "value"} - encoded = jwt.encode(es256_signer, test_payload) - header, payload, _, _ = jwt._unverified_decode(encoded) + encoded = jwt_async.encode(es256_signer, test_payload) + header, payload, _, _ = jwt_async._unverified_decode(encoded) assert payload == test_payload assert header == {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id} @@ -106,23 +84,23 @@ def factory(claims=None, key_id=None, use_es256_signer=False): key_id = None if use_es256_signer: - return jwt.encode(es256_signer, payload, key_id=key_id) + return jwt_async.encode(es256_signer, payload, key_id=key_id) else: - return jwt.encode(signer, payload, key_id=key_id) + return jwt_async.encode(signer, payload, key_id=key_id) return factory def test_decode_valid(token_factory): - payload = jwt.decode(token_factory(), certs=PUBLIC_CERT_BYTES) + payload = jwt_async.decode(token_factory(), certs=test_jwt.PUBLIC_CERT_BYTES) assert payload["aud"] == "audience@example.com" assert payload["user"] == "billy bob" assert payload["metadata"]["meta"] == "data" def test_decode_valid_es256(token_factory): - payload = jwt.decode( - token_factory(use_es256_signer=True), certs=EC_PUBLIC_CERT_BYTES + payload = jwt_async.decode( + token_factory(use_es256_signer=True), certs=test_jwt.EC_PUBLIC_CERT_BYTES ) assert payload["aud"] == "audience@example.com" assert payload["user"] == "billy bob" @@ -130,8 +108,10 @@ def test_decode_valid_es256(token_factory): def test_decode_valid_with_audience(token_factory): - payload = jwt.decode( - token_factory(), certs=PUBLIC_CERT_BYTES, audience="audience@example.com" + payload = jwt_async.decode( + token_factory(), + certs=test_jwt.PUBLIC_CERT_BYTES, + audience="audience@example.com", ) assert payload["aud"] == "audience@example.com" assert payload["user"] == "billy bob" @@ -139,7 +119,9 @@ def test_decode_valid_with_audience(token_factory): def test_decode_valid_unverified(token_factory): - payload = jwt.decode(token_factory(), certs=OTHER_CERT_BYTES, verify=False) + payload = jwt_async.decode( + token_factory(), certs=test_jwt.OTHER_CERT_BYTES, verify=False + ) assert payload["aud"] == "audience@example.com" assert payload["user"] == "billy bob" assert payload["metadata"]["meta"] == "data" @@ -147,27 +129,27 @@ def test_decode_valid_unverified(token_factory): def test_decode_bad_token_wrong_number_of_segments(): with pytest.raises(ValueError) as excinfo: - jwt.decode("1.2", PUBLIC_CERT_BYTES) + jwt_async.decode("1.2", test_jwt.PUBLIC_CERT_BYTES) assert excinfo.match(r"Wrong number of segments") def test_decode_bad_token_not_base64(): with pytest.raises((ValueError, TypeError)) as excinfo: - jwt.decode("1.2.3", PUBLIC_CERT_BYTES) + jwt_async.decode("1.2.3", test_jwt.PUBLIC_CERT_BYTES) assert excinfo.match(r"Incorrect padding|more than a multiple of 4") def test_decode_bad_token_not_json(): token = b".".join([base64.urlsafe_b64encode(b"123!")] * 3) with pytest.raises(ValueError) as excinfo: - jwt.decode(token, PUBLIC_CERT_BYTES) + jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) assert excinfo.match(r"Can\'t parse segment") def test_decode_bad_token_no_iat_or_exp(signer): - token = jwt.encode(signer, {"test": "value"}) + token = jwt_async.encode(signer, {"test": "value"}) with pytest.raises(ValueError) as excinfo: - jwt.decode(token, PUBLIC_CERT_BYTES) + jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) assert excinfo.match(r"Token does not contain required claim") @@ -180,7 +162,7 @@ def test_decode_bad_token_too_early(token_factory): } ) with pytest.raises(ValueError) as excinfo: - jwt.decode(token, PUBLIC_CERT_BYTES) + jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) assert excinfo.match(r"Token used too early") @@ -193,7 +175,7 @@ def test_decode_bad_token_expired(token_factory): } ) with pytest.raises(ValueError) as excinfo: - jwt.decode(token, PUBLIC_CERT_BYTES) + jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) assert excinfo.match(r"Token expired") @@ -201,34 +183,34 @@ def test_decode_bad_token_wrong_audience(token_factory): token = token_factory() audience = "audience2@example.com" with pytest.raises(ValueError) as excinfo: - jwt.decode(token, PUBLIC_CERT_BYTES, audience=audience) + jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES, audience=audience) assert excinfo.match(r"Token has wrong audience") def test_decode_wrong_cert(token_factory): with pytest.raises(ValueError) as excinfo: - jwt.decode(token_factory(), OTHER_CERT_BYTES) + jwt_async.decode(token_factory(), test_jwt.OTHER_CERT_BYTES) assert excinfo.match(r"Could not verify token signature") def test_decode_multicert_bad_cert(token_factory): - certs = {"1": OTHER_CERT_BYTES, "2": PUBLIC_CERT_BYTES} + certs = {"1": test_jwt.OTHER_CERT_BYTES, "2": test_jwt.PUBLIC_CERT_BYTES} with pytest.raises(ValueError) as excinfo: - jwt.decode(token_factory(), certs) + jwt_async.decode(token_factory(), certs) assert excinfo.match(r"Could not verify token signature") def test_decode_no_cert(token_factory): - certs = {"2": PUBLIC_CERT_BYTES} + certs = {"2": test_jwt.PUBLIC_CERT_BYTES} with pytest.raises(ValueError) as excinfo: - jwt.decode(token_factory(), certs) + jwt_async.decode(token_factory(), certs) assert excinfo.match(r"Certificate for key id 1 not found") def test_decode_no_key_id(token_factory): token = token_factory(key_id=False) - certs = {"2": PUBLIC_CERT_BYTES} - payload = jwt.decode(token, certs) + certs = {"2": test_jwt.PUBLIC_CERT_BYTES} + payload = jwt_async.decode(token, certs) assert payload["user"] == "billy bob" @@ -239,7 +221,7 @@ def test_decode_unknown_alg(): ) with pytest.raises(ValueError) as excinfo: - jwt.decode(token) + jwt_async.decode(token) assert excinfo.match(r"fakealg") @@ -251,14 +233,14 @@ def test_decode_missing_crytography_alg(monkeypatch): ) with pytest.raises(ValueError) as excinfo: - jwt.decode(token) + jwt_async.decode(token) assert excinfo.match(r"cryptography") def test_roundtrip_explicit_key_id(token_factory): token = token_factory(key_id="3") - certs = {"2": OTHER_CERT_BYTES, "3": PUBLIC_CERT_BYTES} - payload = jwt.decode(token, certs) + certs = {"2": test_jwt.OTHER_CERT_BYTES, "3": test_jwt.PUBLIC_CERT_BYTES} + payload = jwt_async.decode(token, certs) assert payload["user"] == "billy bob" @@ -271,7 +253,7 @@ class TestCredentials(object): @pytest.fixture(autouse=True) def credentials_fixture(self, signer): - self.credentials = jwt.Credentials( + self.credentials = jwt_async.Credentials( signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL, @@ -279,10 +261,10 @@ def credentials_fixture(self, signer): ) def test_from_service_account_info(self): - with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh: info = json.load(fh) - credentials = jwt.Credentials.from_service_account_info( + credentials = jwt_async.Credentials.from_service_account_info( info, audience=self.AUDIENCE ) @@ -292,9 +274,9 @@ def test_from_service_account_info(self): assert credentials._audience == self.AUDIENCE def test_from_service_account_info_args(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() - credentials = jwt.Credentials.from_service_account_info( + credentials = jwt_async.Credentials.from_service_account_info( info, subject=self.SUBJECT, audience=self.AUDIENCE, @@ -308,10 +290,10 @@ def test_from_service_account_info_args(self): assert credentials._additional_claims == self.ADDITIONAL_CLAIMS def test_from_service_account_file(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() - credentials = jwt.Credentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE + credentials = jwt_async.Credentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE ) assert credentials._signer.key_id == info["private_key_id"] @@ -320,10 +302,10 @@ def test_from_service_account_file(self): assert credentials._audience == self.AUDIENCE def test_from_service_account_file_args(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() - credentials = jwt.Credentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE, + credentials = jwt_async.Credentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE, subject=self.SUBJECT, audience=self.AUDIENCE, additional_claims=self.ADDITIONAL_CLAIMS, @@ -339,11 +321,11 @@ def test_from_signing_credentials(self): jwt_from_signing = self.credentials.from_signing_credentials( self.credentials, audience=mock.sentinel.new_audience ) - jwt_from_info = jwt.Credentials.from_service_account_info( - SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience + jwt_from_info = jwt_async.Credentials.from_service_account_info( + test_jwt.SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience ) - assert isinstance(jwt_from_signing, jwt.Credentials) + assert isinstance(jwt_from_signing, jwt_async.Credentials) assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id assert jwt_from_signing._issuer == jwt_from_info._issuer assert jwt_from_signing._subject == jwt_from_info._subject @@ -379,16 +361,19 @@ def test_with_quota_project(self): def test_sign_bytes(self): to_sign = b"123" signature = self.credentials.sign_bytes(to_sign) - assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES) def test_signer(self): assert isinstance(self.credentials.signer, crypt.RSASigner) def test_signer_email(self): - assert self.credentials.signer_email == SERVICE_ACCOUNT_INFO["client_email"] + assert ( + self.credentials.signer_email + == test_jwt.SERVICE_ACCOUNT_INFO["client_email"] + ) def _verify_token(self, token): - payload = jwt.decode(token, PUBLIC_CERT_BYTES) + payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL return payload @@ -443,7 +428,7 @@ class TestOnDemandCredentials(object): @pytest.fixture(autouse=True) def credentials_fixture(self, signer): - self.credentials = jwt.OnDemandCredentials( + self.credentials = jwt_async.OnDemandCredentials( signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL, @@ -451,19 +436,19 @@ def credentials_fixture(self, signer): ) def test_from_service_account_info(self): - with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh: info = json.load(fh) - credentials = jwt.OnDemandCredentials.from_service_account_info(info) + credentials = jwt_async.OnDemandCredentials.from_service_account_info(info) assert credentials._signer.key_id == info["private_key_id"] assert credentials._issuer == info["client_email"] assert credentials._subject == info["client_email"] def test_from_service_account_info_args(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() - credentials = jwt.OnDemandCredentials.from_service_account_info( + credentials = jwt_async.OnDemandCredentials.from_service_account_info( info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS ) @@ -473,10 +458,10 @@ def test_from_service_account_info_args(self): assert credentials._additional_claims == self.ADDITIONAL_CLAIMS def test_from_service_account_file(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() - credentials = jwt.OnDemandCredentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE + credentials = jwt_async.OnDemandCredentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE ) assert credentials._signer.key_id == info["private_key_id"] @@ -484,10 +469,10 @@ def test_from_service_account_file(self): assert credentials._subject == info["client_email"] def test_from_service_account_file_args(self): - info = SERVICE_ACCOUNT_INFO.copy() + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() - credentials = jwt.OnDemandCredentials.from_service_account_file( - SERVICE_ACCOUNT_JSON_FILE, + credentials = jwt_async.OnDemandCredentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS, ) @@ -499,11 +484,11 @@ def test_from_service_account_file_args(self): def test_from_signing_credentials(self): jwt_from_signing = self.credentials.from_signing_credentials(self.credentials) - jwt_from_info = jwt.OnDemandCredentials.from_service_account_info( - SERVICE_ACCOUNT_INFO + jwt_from_info = jwt_async.OnDemandCredentials.from_service_account_info( + test_jwt.SERVICE_ACCOUNT_INFO ) - assert isinstance(jwt_from_signing, jwt.OnDemandCredentials) + assert isinstance(jwt_from_signing, jwt_async.OnDemandCredentials) assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id assert jwt_from_signing._issuer == jwt_from_info._issuer assert jwt_from_signing._subject == jwt_from_info._subject @@ -536,16 +521,19 @@ def test_with_quota_project(self): def test_sign_bytes(self): to_sign = b"123" signature = self.credentials.sign_bytes(to_sign) - assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES) def test_signer(self): assert isinstance(self.credentials.signer, crypt.RSASigner) def test_signer_email(self): - assert self.credentials.signer_email == SERVICE_ACCOUNT_INFO["client_email"] + assert ( + self.credentials.signer_email + == test_jwt.SERVICE_ACCOUNT_INFO["client_email"] + ) def _verify_token(self, token): - payload = jwt.decode(token, PUBLIC_CERT_BYTES) + payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL return payload diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py index 4bb027399..9c4b173c2 100644 --- a/tests_async/transport/async_compliance.py +++ b/tests_async/transport/async_compliance.py @@ -19,11 +19,8 @@ from pytest_localserver.http import WSGIServer from six.moves import http_client - from google.auth import exceptions - -# .invalid will never resolve, see https://tools.ietf.org/html/rfc2606 -NXDOMAIN = "test.invalid" +from tests.transport import compliance class RequestResponseTests(object): @@ -133,4 +130,4 @@ async def test_connection_error(self): request = self.make_request() with pytest.raises(exceptions.TransportError): - await request(url="http://{}".format(NXDOMAIN), method="GET") + await request(url="http://{}".format(compliance.NXDOMAIN), method="GET") From 9ec8277b411ff2d7eac806008187fe8760c9494c Mon Sep 17 00:00:00 2001 From: AniBadde Date: Wed, 29 Jul 2020 11:31:32 -0500 Subject: [PATCH 15/20] fix: async docstring --- google/oauth2/_client_async.py | 2 +- google/oauth2/credentials_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index 99bae7f55..9f722a1e4 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""OAuth 2.0 client. +"""OAuth 2.0 async client. This is a client for interacting with an OAuth 2.0 authorization server's token endpoint. diff --git a/google/oauth2/credentials_async.py b/google/oauth2/credentials_async.py index 092aa5781..2081a0be2 100644 --- a/google/oauth2/credentials_async.py +++ b/google/oauth2/credentials_async.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""OAuth 2.0 Credentials. +"""OAuth 2.0 Async Credentials. This module provides credentials based on OAuth 2.0 access and refresh tokens. These credentials usually access resources on behalf of a user (resource From 8f254de065054888bc5e997f164c288d027ecc01 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Wed, 29 Jul 2020 16:46:37 -0500 Subject: [PATCH 16/20] refactoring --- google/auth/_default_async.py | 45 ++------ google/auth/_oauth2client_async.py | 26 +++-- google/auth/jwt_async.py | 167 +---------------------------- 3 files changed, 26 insertions(+), 212 deletions(-) diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py index 2c367e70a..ce85e3220 100644 --- a/google/auth/_default_async.py +++ b/google/auth/_default_async.py @@ -24,10 +24,9 @@ import six -from google.auth import _default as default +from google.auth import _default from google.auth import environment_vars from google.auth import exceptions -import google.auth.transport._http_client def _warn_about_problematic_credentials(credentials): @@ -40,7 +39,7 @@ def _warn_about_problematic_credentials(credentials): from google.auth import _cloud_sdk if credentials.client_id == _cloud_sdk.CLOUD_SDK_CLIENT_ID: - warnings.warn(default._CLOUD_SDK_CREDENTIALS_WARNING) + warnings.warn(_default._CLOUD_SDK_CREDENTIALS_WARNING) def load_credentials_from_file(filename, scopes=None, quota_project_id=None): @@ -84,7 +83,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): # credentials file or an authorized user credentials file. credential_type = info.get("type") - if credential_type == default._AUTHORIZED_USER_TYPE: + if credential_type == _default._AUTHORIZED_USER_TYPE: from google.oauth2 import credentials_async as credentials try: @@ -99,7 +98,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): _warn_about_problematic_credentials(credentials) return credentials, None - elif credential_type == default._SERVICE_ACCOUNT_TYPE: + elif credential_type == _default._SERVICE_ACCOUNT_TYPE: from google.oauth2 import service_account_async as service_account try: @@ -116,7 +115,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): raise exceptions.DefaultCredentialsError( "The file {file} does not have a valid type. " "Type is {type}, expected one of {valid_types}.".format( - file=filename, type=credential_type, valid_types=default._VALID_TYPES + file=filename, type=credential_type, valid_types=_default._VALID_TYPES ) ) @@ -159,17 +158,8 @@ def _get_gae_credentials(): """Gets Google App Engine App Identity credentials and project ID.""" # While this library is normally bundled with app_engine, there are # some cases where it's not available, so we tolerate ImportError. - try: - import google.auth.app_engine as app_engine - except ImportError: - return None, None - try: - credentials = app_engine.Credentials() - project_id = app_engine.get_project_id() - return credentials, project_id - except EnvironmentError: - return None, None + return _default._get_gae_credentials() def _get_gce_credentials(request=None): @@ -181,25 +171,8 @@ def _get_gce_credentials(request=None): # While this library is normally bundled with compute_engine, there are # some cases where it's not available, so we tolerate ImportError. - try: - from google.auth import compute_engine - from google.auth.compute_engine import _metadata - except ImportError: - return None, None - - if request is None: - request = google.auth.transport._http_client.Request() - if _metadata.ping(request=request): - # Get the project ID. - try: - project_id = _metadata.get_project_id(request=request) - except exceptions.TransportError: - project_id = None - - return compute_engine.Credentials(), project_id - else: - return None, None + return _default._get_gce_credentials(request) def default_async(scopes=None, request=None, quota_project_id=None): @@ -296,7 +269,7 @@ def default_async(scopes=None, request=None, quota_project_id=None): ).with_quota_project(quota_project_id) effective_project_id = explicit_project_id or project_id if not effective_project_id: - default._LOGGER.warning( + _default._LOGGER.warning( "No project ID could be determined. Consider running " "`gcloud config set project` or setting the %s " "environment variable", @@ -304,4 +277,4 @@ def default_async(scopes=None, request=None, quota_project_id=None): ) return credentials, effective_project_id - raise exceptions.DefaultCredentialsError(default._HELP_MESSAGE) + raise exceptions.DefaultCredentialsError(_default._HELP_MESSAGE) diff --git a/google/auth/_oauth2client_async.py b/google/auth/_oauth2client_async.py index 870e55019..70e399738 100644 --- a/google/auth/_oauth2client_async.py +++ b/google/auth/_oauth2client_async.py @@ -23,11 +23,9 @@ import six -from google.auth import _helpers from google.auth import _oauth2client import google.auth.app_engine import google.auth.compute_engine -import google.oauth2.credentials import google.oauth2.service_account_async @@ -58,7 +56,7 @@ def _convert_oauth2_credentials(credentials): Returns: google.oauth2.credentials_async.Credentials: The converted credentials. """ - new_credentials = google.oauth2.credentials.Credentials( + new_credentials = google.oauth2.credentials_async.Credentials( token=credentials.access_token, refresh_token=credentials.refresh_token, token_uri=credentials.token_uri, @@ -101,9 +99,7 @@ def _convert_gce_app_assertion_credentials(credentials): Returns: google.oauth2.service_account_async.Credentials: The converted credentials. """ - return google.auth.compute_engine.Credentials( - service_account_email=credentials.service_account_email - ) + return _oauth2client._convert_gce_app_assertion_credentials(credentials) def _convert_appengine_app_assertion_credentials(credentials): @@ -117,10 +113,7 @@ def _convert_appengine_app_assertion_credentials(credentials): google.oauth2.service_account_async.Credentials: The converted credentials. """ # pylint: disable=invalid-name - return google.auth.app_engine.Credentials( - scopes=_helpers.string_to_scopes(credentials.scope), - service_account_id=credentials.service_account_id, - ) + return _oauth2client._convert_appengine_app_assertion_credentials(credentials) _CLASS_CONVERSION_MAP = { @@ -142,13 +135,13 @@ def convert(credentials): This class converts: - - :class:`oauth2client.client_async.OAuth2Credentials` to + - :class:`oauth2client.client.OAuth2Credentials` to :class:`google.oauth2.credentials_async.Credentials`. - - :class:`oauth2client.client_async.GoogleCredentials` to + - :class:`oauth2client.client.GoogleCredentials` to :class:`google.oauth2.credentials_async.Credentials`. - - :class:`oauth2client.service_account_async.ServiceAccountCredentials` to + - :class:`oauth2client.service_account.ServiceAccountCredentials` to :class:`google.oauth2.service_account_async.Credentials`. - - :class:`oauth2client.service_account_async._JWTAccessCredentials` to + - :class:`oauth2client.service_account._JWTAccessCredentials` to :class:`google.oauth2.service_account_async.Credentials`. - :class:`oauth2client.contrib.gce.AppAssertionCredentials` to :class:`google.auth.compute_engine.Credentials`. @@ -162,6 +155,11 @@ def convert(credentials): ValueError: If the credentials could not be converted. """ + """ + Not able to inherit from the sync _oauth2client file (calling _oauth2client.convert), + and therefore reproduced here. + """ + credentials_class = type(credentials) try: diff --git a/google/auth/jwt_async.py b/google/auth/jwt_async.py index a5293d73f..ffd8a9bb1 100644 --- a/google/auth/jwt_async.py +++ b/google/auth/jwt_async.py @@ -40,26 +40,9 @@ """ -try: - from collections.abc import Mapping -# Python 2.7 compatibility -except ImportError: # pragma: NO COVER - from collections import Mapping - -import json - -import six - import google.auth -from google.auth import _helpers -from google.auth import crypt from google.auth import jwt -try: - from google.auth.crypt import es256 -except ImportError: # pragma: NO COVER - es256 = None - def encode(signer, payload, header=None, key_id=None): """Make a signed JWT. @@ -75,42 +58,7 @@ def encode(signer, payload, header=None, key_id=None): Returns: bytes: The encoded JWT. """ - if header is None: - header = {} - - if key_id is None: - key_id = signer.key_id - - header.update({"typ": "JWT"}) - - if es256 is not None and isinstance(signer, es256.ES256Signer): - header.update({"alg": "ES256"}) - else: - header.update({"alg": "RS256"}) - - if key_id is not None: - header["kid"] = key_id - - segments = [ - _helpers.unpadded_urlsafe_b64encode(json.dumps(header).encode("utf-8")), - _helpers.unpadded_urlsafe_b64encode(json.dumps(payload).encode("utf-8")), - ] - - signing_input = b".".join(segments) - signature = signer.sign(signing_input) - segments.append(_helpers.unpadded_urlsafe_b64encode(signature)) - - return b".".join(segments) - - -def _decode_jwt_segment(encoded_section): - """Decodes a single JWT segment.""" - section_bytes = _helpers.padded_urlsafe_b64decode(encoded_section) - try: - return json.loads(section_bytes.decode("utf-8")) - except ValueError as caught_exc: - new_exc = ValueError("Can't parse segment: {0}".format(section_bytes)) - six.raise_from(new_exc, caught_exc) + return jwt.encode(signer, payload, header, key_id) def _unverified_decode(token): @@ -126,20 +74,7 @@ def _unverified_decode(token): Raises: ValueError: if there are an incorrect amount of segments in the token. """ - token = _helpers.to_bytes(token) - - if token.count(b".") != 2: - raise ValueError("Wrong number of segments in token: {0}".format(token)) - - encoded_header, encoded_payload, signature = token.split(b".") - signed_section = encoded_header + b"." + encoded_payload - signature = _helpers.padded_urlsafe_b64decode(signature) - - # Parse segments - header = _decode_jwt_segment(encoded_header) - payload = _decode_jwt_segment(encoded_payload) - - return header, payload, signed_section, signature + return jwt._unverified_decode(token) def decode_header(token): @@ -155,42 +90,7 @@ def decode_header(token): Returns: Mapping: The decoded JWT header. """ - header, _, _, _ = _unverified_decode(token) - return header - - -def _verify_iat_and_exp(payload): - """Verifies the ``iat`` (Issued At) and ``exp`` (Expires) claims in a token - payload. - - Args: - payload (Mapping[str, str]): The JWT payload. - - Raises: - ValueError: if any checks failed. - """ - now = _helpers.datetime_to_secs(_helpers.utcnow()) - - # Make sure the iat and exp claims are present. - for key in ("iat", "exp"): - if key not in payload: - raise ValueError("Token does not contain required claim {}".format(key)) - - # Make sure the token wasn't issued in the future. - iat = payload["iat"] - # Err on the side of accepting a token that is slightly early to account - # for clock skew. - earliest = iat - _helpers.CLOCK_SKEW_SECS - if now < earliest: - raise ValueError("Token used too early, {} < {}".format(now, iat)) - - # Make sure the token wasn't issued in the past. - exp = payload["exp"] - # Err on the side of accepting a token that is slightly out of date - # to account for clow skew. - latest = exp + _helpers.CLOCK_SKEW_SECS - if latest < now: - raise ValueError("Token expired, {} < {}".format(latest, now)) + return jwt.decode_header(token) def decode(token, certs=None, verify=True, audience=None): @@ -215,65 +115,8 @@ def decode(token, certs=None, verify=True, audience=None): Raises: ValueError: if any verification checks failed. """ - header, payload, signed_section, signature = _unverified_decode(token) - - if not verify: - return payload - - # Pluck the key id and algorithm from the header and make sure we have - # a verifier that can support it. - key_alg = header.get("alg") - key_id = header.get("kid") - - try: - verifier_cls = jwt._ALGORITHM_TO_VERIFIER_CLASS[key_alg] - except KeyError as exc: - if key_alg in jwt._CRYPTOGRAPHY_BASED_ALGORITHMS: - six.raise_from( - ValueError( - "The key algorithm {} requires the cryptography package " - "to be installed.".format(key_alg) - ), - exc, - ) - else: - six.raise_from( - ValueError("Unsupported signature algorithm {}".format(key_alg)), exc - ) - - # If certs is specified as a dictionary of key IDs to certificates, then - # use the certificate identified by the key ID in the token header. - if isinstance(certs, Mapping): - if key_id: - if key_id not in certs: - raise ValueError("Certificate for key id {} not found.".format(key_id)) - certs_to_check = [certs[key_id]] - # If there's no key id in the header, check against all of the certs. - else: - certs_to_check = certs.values() - else: - certs_to_check = certs - - # Verify that the signature matches the message. - if not crypt.verify_signature( - signed_section, signature, certs_to_check, verifier_cls - ): - raise ValueError("Could not verify token signature.") - - # Verify the issued at and created times in the payload. - _verify_iat_and_exp(payload) - - # Check audience. - if audience is not None: - claim_audience = payload.get("aud") - if audience != claim_audience: - raise ValueError( - "Token has wrong audience {}, expected {}".format( - claim_audience, audience - ) - ) - - return payload + + return jwt.decode(token, certs, verify, audience) class Credentials( From aa04dc9bb3ff6e0d6449c50f64ba8f28f9e8dd47 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Wed, 29 Jul 2020 17:10:56 -0500 Subject: [PATCH 17/20] fix: refactoring --- docs/reference/google.auth.credentials_async.rst | 7 +++++++ docs/reference/google.auth.jwt_async.rst | 7 +++++++ docs/reference/google.auth.rst | 2 ++ docs/reference/google.auth.transport.aiohttp_requests.rst | 7 +++++++ docs/reference/google.auth.transport.mtls.rst | 7 +++++++ docs/reference/google.auth.transport.rst | 1 + docs/reference/google.oauth2.credentials_async.rst | 7 +++++++ docs/reference/google.oauth2.rst | 2 ++ docs/reference/google.oauth2.service_account_async.rst | 7 +++++++ 9 files changed, 47 insertions(+) create mode 100644 docs/reference/google.auth.credentials_async.rst create mode 100644 docs/reference/google.auth.jwt_async.rst create mode 100644 docs/reference/google.auth.transport.aiohttp_requests.rst create mode 100644 docs/reference/google.auth.transport.mtls.rst create mode 100644 docs/reference/google.oauth2.credentials_async.rst create mode 100644 docs/reference/google.oauth2.service_account_async.rst diff --git a/docs/reference/google.auth.credentials_async.rst b/docs/reference/google.auth.credentials_async.rst new file mode 100644 index 000000000..4e4641e0f --- /dev/null +++ b/docs/reference/google.auth.credentials_async.rst @@ -0,0 +1,7 @@ +google.auth.credentials\_async module +===================================== + +.. automodule:: google.auth.credentials_async + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.jwt_async.rst b/docs/reference/google.auth.jwt_async.rst new file mode 100644 index 000000000..4e56a6ea3 --- /dev/null +++ b/docs/reference/google.auth.jwt_async.rst @@ -0,0 +1,7 @@ +google.auth.jwt\_async module +============================= + +.. automodule:: google.auth.jwt_async + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.rst b/docs/reference/google.auth.rst index cfcf70357..2f6fe1454 100644 --- a/docs/reference/google.auth.rst +++ b/docs/reference/google.auth.rst @@ -24,8 +24,10 @@ Submodules google.auth.app_engine google.auth.credentials + google.auth.credentials_async google.auth.environment_vars google.auth.exceptions google.auth.iam google.auth.impersonated_credentials google.auth.jwt + google.auth.jwt_async diff --git a/docs/reference/google.auth.transport.aiohttp_requests.rst b/docs/reference/google.auth.transport.aiohttp_requests.rst new file mode 100644 index 000000000..bc3e74381 --- /dev/null +++ b/docs/reference/google.auth.transport.aiohttp_requests.rst @@ -0,0 +1,7 @@ +google.auth.transport.aiohttp\_requests module +============================================== + +.. automodule:: google.auth.transport.aiohttp_requests + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.transport.mtls.rst b/docs/reference/google.auth.transport.mtls.rst new file mode 100644 index 000000000..11b50e23c --- /dev/null +++ b/docs/reference/google.auth.transport.mtls.rst @@ -0,0 +1,7 @@ +google.auth.transport.mtls module +================================= + +.. automodule:: google.auth.transport.mtls + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.transport.rst b/docs/reference/google.auth.transport.rst index 89218632b..eba29d037 100644 --- a/docs/reference/google.auth.transport.rst +++ b/docs/reference/google.auth.transport.rst @@ -12,6 +12,7 @@ Submodules .. toctree:: :maxdepth: 4 + google.auth.transport.aiohttp_requests google.auth.transport.grpc google.auth.transport.mtls google.auth.transport.requests diff --git a/docs/reference/google.oauth2.credentials_async.rst b/docs/reference/google.oauth2.credentials_async.rst new file mode 100644 index 000000000..20cb6b684 --- /dev/null +++ b/docs/reference/google.oauth2.credentials_async.rst @@ -0,0 +1,7 @@ +google.oauth2.credentials\_async module +======================================= + +.. automodule:: google.oauth2.credentials_async + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.oauth2.rst b/docs/reference/google.oauth2.rst index 1ac9c7320..75955187a 100644 --- a/docs/reference/google.oauth2.rst +++ b/docs/reference/google.oauth2.rst @@ -13,5 +13,7 @@ Submodules :maxdepth: 4 google.oauth2.credentials + google.oauth2.credentials_async google.oauth2.id_token google.oauth2.service_account + google.oauth2.service_account_async diff --git a/docs/reference/google.oauth2.service_account_async.rst b/docs/reference/google.oauth2.service_account_async.rst new file mode 100644 index 000000000..c48c3e248 --- /dev/null +++ b/docs/reference/google.oauth2.service_account_async.rst @@ -0,0 +1,7 @@ +google.oauth2.service\_account\_async module +============================================ + +.. automodule:: google.oauth2.service_account_async + :members: + :inherited-members: + :show-inheritance: From 0c4c3b6b62cdf46ee40bc11c8d34ed3d58fd4cb7 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Thu, 30 Jul 2020 11:23:50 -0500 Subject: [PATCH 18/20] fix: first round of comments, refactoring and test duplication changes --- google/auth/_default_async.py | 6 +- google/auth/jwt_async.py | 32 --- google/oauth2/_client_async.py | 7 - ...est__default.py => test__default_async.py} | 0 ...2client.py => test__oauth2client_async.py} | 0 ...edentials.py => test_credentials_async.py} | 0 .../{test_jwt.py => test_jwt_async.py} | 215 ------------------ 7 files changed, 1 insertion(+), 259 deletions(-) rename tests_async/{test__default.py => test__default_async.py} (100%) rename tests_async/{test__oauth2client.py => test__oauth2client_async.py} (100%) rename tests_async/{test_credentials.py => test_credentials_async.py} (100%) rename tests_async/{test_jwt.py => test_jwt_async.py} (65%) diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py index ce85e3220..1ceeb19e6 100644 --- a/google/auth/_default_async.py +++ b/google/auth/_default_async.py @@ -20,7 +20,6 @@ import io import json import os -import warnings import six @@ -36,10 +35,7 @@ def _warn_about_problematic_credentials(credentials): are problematic because they may not have APIs enabled and have limited quota. If this is the case, warn about it. """ - from google.auth import _cloud_sdk - - if credentials.client_id == _cloud_sdk.CLOUD_SDK_CLIENT_ID: - warnings.warn(_default._CLOUD_SDK_CREDENTIALS_WARNING) + return _default._warn_about_problematic_credentials(credentials) def load_credentials_from_file(filename, scopes=None, quota_project_id=None): diff --git a/google/auth/jwt_async.py b/google/auth/jwt_async.py index ffd8a9bb1..daa5e3ee9 100644 --- a/google/auth/jwt_async.py +++ b/google/auth/jwt_async.py @@ -61,38 +61,6 @@ def encode(signer, payload, header=None, key_id=None): return jwt.encode(signer, payload, header, key_id) -def _unverified_decode(token): - """Decodes a token and does no verification. - - Args: - token (Union[str, bytes]): The encoded JWT. - - Returns: - Tuple[str, str, str, str]: header, payload, signed_section, and - signature. - - Raises: - ValueError: if there are an incorrect amount of segments in the token. - """ - return jwt._unverified_decode(token) - - -def decode_header(token): - """Return the decoded header of a token. - - No verification is done. This is useful to extract the key id from - the header in order to acquire the appropriate certificate to verify - the token. - - Args: - token (Union[str, bytes]): the encoded JWT. - - Returns: - Mapping: The decoded JWT header. - """ - return jwt.decode_header(token) - - def decode(token, certs=None, verify=True, audience=None): """Decode and verify a JWT. diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index 9f722a1e4..a6cc3b292 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -104,12 +104,6 @@ async def _token_endpoint_request(request, token_uri, body): method="POST", url=token_uri, headers=headers, body=body ) - """ - except exceptions.TransportError as caught_exc: - new_exc = exceptions.RefreshError(caught_exc) - six.raise_from(new_exc, caught_exc) - """ - response_body1 = await response.data.read() response_body = ( @@ -117,7 +111,6 @@ async def _token_endpoint_request(request, token_uri, body): if hasattr(response_body1, "decode") else response_body1 ) - # CHANGE TO READ TO END OF STREAM response_data = json.loads(response_body) diff --git a/tests_async/test__default.py b/tests_async/test__default_async.py similarity index 100% rename from tests_async/test__default.py rename to tests_async/test__default_async.py diff --git a/tests_async/test__oauth2client.py b/tests_async/test__oauth2client_async.py similarity index 100% rename from tests_async/test__oauth2client.py rename to tests_async/test__oauth2client_async.py diff --git a/tests_async/test_credentials.py b/tests_async/test_credentials_async.py similarity index 100% rename from tests_async/test_credentials.py rename to tests_async/test_credentials_async.py diff --git a/tests_async/test_jwt.py b/tests_async/test_jwt_async.py similarity index 65% rename from tests_async/test_jwt.py rename to tests_async/test_jwt_async.py index aa9f39f23..b5a499027 100644 --- a/tests_async/test_jwt.py +++ b/tests_async/test_jwt_async.py @@ -12,17 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import base64 import datetime import json import mock import pytest -from google.auth import _helpers from google.auth import crypt from google.auth import exceptions -from google.auth import jwt from google.auth import jwt_async from tests import test_jwt @@ -32,218 +29,6 @@ def signer(): return crypt.RSASigner.from_string(test_jwt.PRIVATE_KEY_BYTES, "1") -def test_encode_basic(signer): - test_payload = {"test": "value"} - encoded = jwt_async.encode(signer, test_payload) - header, payload, _, _ = jwt_async._unverified_decode(encoded) - assert payload == test_payload - assert header == {"typ": "JWT", "alg": "RS256", "kid": signer.key_id} - - -def test_encode_extra_headers(signer): - encoded = jwt_async.encode(signer, {}, header={"extra": "value"}) - header = jwt_async.decode_header(encoded) - assert header == { - "typ": "JWT", - "alg": "RS256", - "kid": signer.key_id, - "extra": "value", - } - - -@pytest.fixture -def es256_signer(): - return crypt.ES256Signer.from_string(test_jwt.EC_PRIVATE_KEY_BYTES, "1") - - -def test_encode_basic_es256(es256_signer): - test_payload = {"test": "value"} - encoded = jwt_async.encode(es256_signer, test_payload) - header, payload, _, _ = jwt_async._unverified_decode(encoded) - assert payload == test_payload - assert header == {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id} - - -@pytest.fixture -def token_factory(signer, es256_signer): - def factory(claims=None, key_id=None, use_es256_signer=False): - now = _helpers.datetime_to_secs(_helpers.utcnow()) - payload = { - "aud": "audience@example.com", - "iat": now, - "exp": now + 300, - "user": "billy bob", - "metadata": {"meta": "data"}, - } - payload.update(claims or {}) - - # False is specified to remove the signer's key id for testing - # headers without key ids. - if key_id is False: - signer._key_id = None - key_id = None - - if use_es256_signer: - return jwt_async.encode(es256_signer, payload, key_id=key_id) - else: - return jwt_async.encode(signer, payload, key_id=key_id) - - return factory - - -def test_decode_valid(token_factory): - payload = jwt_async.decode(token_factory(), certs=test_jwt.PUBLIC_CERT_BYTES) - assert payload["aud"] == "audience@example.com" - assert payload["user"] == "billy bob" - assert payload["metadata"]["meta"] == "data" - - -def test_decode_valid_es256(token_factory): - payload = jwt_async.decode( - token_factory(use_es256_signer=True), certs=test_jwt.EC_PUBLIC_CERT_BYTES - ) - assert payload["aud"] == "audience@example.com" - assert payload["user"] == "billy bob" - assert payload["metadata"]["meta"] == "data" - - -def test_decode_valid_with_audience(token_factory): - payload = jwt_async.decode( - token_factory(), - certs=test_jwt.PUBLIC_CERT_BYTES, - audience="audience@example.com", - ) - assert payload["aud"] == "audience@example.com" - assert payload["user"] == "billy bob" - assert payload["metadata"]["meta"] == "data" - - -def test_decode_valid_unverified(token_factory): - payload = jwt_async.decode( - token_factory(), certs=test_jwt.OTHER_CERT_BYTES, verify=False - ) - assert payload["aud"] == "audience@example.com" - assert payload["user"] == "billy bob" - assert payload["metadata"]["meta"] == "data" - - -def test_decode_bad_token_wrong_number_of_segments(): - with pytest.raises(ValueError) as excinfo: - jwt_async.decode("1.2", test_jwt.PUBLIC_CERT_BYTES) - assert excinfo.match(r"Wrong number of segments") - - -def test_decode_bad_token_not_base64(): - with pytest.raises((ValueError, TypeError)) as excinfo: - jwt_async.decode("1.2.3", test_jwt.PUBLIC_CERT_BYTES) - assert excinfo.match(r"Incorrect padding|more than a multiple of 4") - - -def test_decode_bad_token_not_json(): - token = b".".join([base64.urlsafe_b64encode(b"123!")] * 3) - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) - assert excinfo.match(r"Can\'t parse segment") - - -def test_decode_bad_token_no_iat_or_exp(signer): - token = jwt_async.encode(signer, {"test": "value"}) - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) - assert excinfo.match(r"Token does not contain required claim") - - -def test_decode_bad_token_too_early(token_factory): - token = token_factory( - claims={ - "iat": _helpers.datetime_to_secs( - _helpers.utcnow() + datetime.timedelta(hours=1) - ) - } - ) - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) - assert excinfo.match(r"Token used too early") - - -def test_decode_bad_token_expired(token_factory): - token = token_factory( - claims={ - "exp": _helpers.datetime_to_secs( - _helpers.utcnow() - datetime.timedelta(hours=1) - ) - } - ) - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) - assert excinfo.match(r"Token expired") - - -def test_decode_bad_token_wrong_audience(token_factory): - token = token_factory() - audience = "audience2@example.com" - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES, audience=audience) - assert excinfo.match(r"Token has wrong audience") - - -def test_decode_wrong_cert(token_factory): - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token_factory(), test_jwt.OTHER_CERT_BYTES) - assert excinfo.match(r"Could not verify token signature") - - -def test_decode_multicert_bad_cert(token_factory): - certs = {"1": test_jwt.OTHER_CERT_BYTES, "2": test_jwt.PUBLIC_CERT_BYTES} - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token_factory(), certs) - assert excinfo.match(r"Could not verify token signature") - - -def test_decode_no_cert(token_factory): - certs = {"2": test_jwt.PUBLIC_CERT_BYTES} - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token_factory(), certs) - assert excinfo.match(r"Certificate for key id 1 not found") - - -def test_decode_no_key_id(token_factory): - token = token_factory(key_id=False) - certs = {"2": test_jwt.PUBLIC_CERT_BYTES} - payload = jwt_async.decode(token, certs) - assert payload["user"] == "billy bob" - - -def test_decode_unknown_alg(): - headers = json.dumps({u"kid": u"1", u"alg": u"fakealg"}) - token = b".".join( - map(lambda seg: base64.b64encode(seg.encode("utf-8")), [headers, u"{}", u"sig"]) - ) - - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token) - assert excinfo.match(r"fakealg") - - -def test_decode_missing_crytography_alg(monkeypatch): - monkeypatch.delitem(jwt._ALGORITHM_TO_VERIFIER_CLASS, "ES256") - headers = json.dumps({u"kid": u"1", u"alg": u"ES256"}) - token = b".".join( - map(lambda seg: base64.b64encode(seg.encode("utf-8")), [headers, u"{}", u"sig"]) - ) - - with pytest.raises(ValueError) as excinfo: - jwt_async.decode(token) - assert excinfo.match(r"cryptography") - - -def test_roundtrip_explicit_key_id(token_factory): - token = token_factory(key_id="3") - certs = {"2": test_jwt.OTHER_CERT_BYTES, "3": test_jwt.PUBLIC_CERT_BYTES} - payload = jwt_async.decode(token, certs) - assert payload["user"] == "billy bob" - - class TestCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" SUBJECT = "subject" From 92175f27859ccaf68a7b1ec444bb26ed6580a165 Mon Sep 17 00:00:00 2001 From: AniBadde Date: Thu, 30 Jul 2020 11:40:14 -0500 Subject: [PATCH 19/20] fix: removed duplication in _default_async --- google/auth/_default_async.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py index 1ceeb19e6..15f405ae6 100644 --- a/google/auth/_default_async.py +++ b/google/auth/_default_async.py @@ -28,16 +28,6 @@ from google.auth import exceptions -def _warn_about_problematic_credentials(credentials): - """Determines if the credentials are problematic. - - Credentials from the Cloud SDK that are associated with Cloud SDK's project - are problematic because they may not have APIs enabled and have limited - quota. If this is the case, warn about it. - """ - return _default._warn_about_problematic_credentials(credentials) - - def load_credentials_from_file(filename, scopes=None, quota_project_id=None): """Loads Google credentials from a file. @@ -91,7 +81,7 @@ def load_credentials_from_file(filename, scopes=None, quota_project_id=None): new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) six.raise_from(new_exc, caught_exc) if not credentials.quota_project_id: - _warn_about_problematic_credentials(credentials) + _default._warn_about_problematic_credentials(credentials) return credentials, None elif credential_type == _default._SERVICE_ACCOUNT_TYPE: From 77d7b6e1c45a7982b3d161e755129ccf44377bcc Mon Sep 17 00:00:00 2001 From: AniBadde Date: Thu, 30 Jul 2020 17:59:54 -0500 Subject: [PATCH 20/20] fix: removed oauth2 client --- google/auth/_oauth2client_async.py | 171 ------------------------ noxfile.py | 2 +- tests_async/test__oauth2client_async.py | 170 ----------------------- 3 files changed, 1 insertion(+), 342 deletions(-) delete mode 100644 google/auth/_oauth2client_async.py delete mode 100644 tests_async/test__oauth2client_async.py diff --git a/google/auth/_oauth2client_async.py b/google/auth/_oauth2client_async.py deleted file mode 100644 index 70e399738..000000000 --- a/google/auth/_oauth2client_async.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Helpers for transitioning from oauth2client to google-auth. - -.. warning:: - This module is private as it is intended to assist first-party downstream - clients with the transition from oauth2client to google-auth. -""" - -from __future__ import absolute_import - -import six - -from google.auth import _oauth2client -import google.auth.app_engine -import google.auth.compute_engine -import google.oauth2.service_account_async - - -try: - import oauth2client - import oauth2client.client - import oauth2client.contrib.gce - import oauth2client.service_account -except ImportError as caught_exc: - six.raise_from(ImportError("oauth2client is not installed."), caught_exc) - -try: - import oauth2client.contrib.appengine # pytype: disable=import-error - - _HAS_APPENGINE = True -except ImportError: - _HAS_APPENGINE = False - - -def _convert_oauth2_credentials(credentials): - """Converts to :class:`google.oauth2.credentials_async.Credentials`. - - Args: - credentials (Union[oauth2client.client.OAuth2Credentials, - oauth2client.client.GoogleCredentials]): The credentials to - convert. - - Returns: - google.oauth2.credentials_async.Credentials: The converted credentials. - """ - new_credentials = google.oauth2.credentials_async.Credentials( - token=credentials.access_token, - refresh_token=credentials.refresh_token, - token_uri=credentials.token_uri, - client_id=credentials.client_id, - client_secret=credentials.client_secret, - scopes=credentials.scopes, - ) - - new_credentials._expires = credentials.token_expiry - - return new_credentials - - -def _convert_service_account_credentials(credentials): - """Converts to :class:`google.oauth2.service_account_async.Credentials`. - - Args: - credentials (Union[ - oauth2client.service_account_async.ServiceAccountCredentials, - oauth2client.service_account_async._JWTAccessCredentials]): The - credentials to convert. - - Returns: - google.oauth2.service_account_async.Credentials: The converted credentials. - """ - info = credentials.serialization_data.copy() - info["token_uri"] = credentials.token_uri - return google.oauth2.service_account_async.Credentials.from_service_account_info( - info - ) - - -def _convert_gce_app_assertion_credentials(credentials): - """Converts to :class:`google.auth.compute_engine.Credentials`. - - Args: - credentials (oauth2client.contrib.gce.AppAssertionCredentials): The - credentials to convert. - - Returns: - google.oauth2.service_account_async.Credentials: The converted credentials. - """ - return _oauth2client._convert_gce_app_assertion_credentials(credentials) - - -def _convert_appengine_app_assertion_credentials(credentials): - """Converts to :class:`google.auth.app_engine.Credentials`. - - Args: - credentials (oauth2client.contrib.app_engine.AppAssertionCredentials): - The credentials to convert. - - Returns: - google.oauth2.service_account_async.Credentials: The converted credentials. - """ - # pylint: disable=invalid-name - return _oauth2client._convert_appengine_app_assertion_credentials(credentials) - - -_CLASS_CONVERSION_MAP = { - oauth2client.client.OAuth2Credentials: _convert_oauth2_credentials, - oauth2client.client.GoogleCredentials: _convert_oauth2_credentials, - oauth2client.service_account.ServiceAccountCredentials: _convert_service_account_credentials, - oauth2client.service_account._JWTAccessCredentials: _convert_service_account_credentials, - oauth2client.contrib.gce.AppAssertionCredentials: _convert_gce_app_assertion_credentials, -} - -if _HAS_APPENGINE: - _CLASS_CONVERSION_MAP[ - oauth2client.contrib.appengine.AppAssertionCredentials - ] = _convert_appengine_app_assertion_credentials - - -def convert(credentials): - """Convert oauth2client credentials to google-auth credentials. - - This class converts: - - - :class:`oauth2client.client.OAuth2Credentials` to - :class:`google.oauth2.credentials_async.Credentials`. - - :class:`oauth2client.client.GoogleCredentials` to - :class:`google.oauth2.credentials_async.Credentials`. - - :class:`oauth2client.service_account.ServiceAccountCredentials` to - :class:`google.oauth2.service_account_async.Credentials`. - - :class:`oauth2client.service_account._JWTAccessCredentials` to - :class:`google.oauth2.service_account_async.Credentials`. - - :class:`oauth2client.contrib.gce.AppAssertionCredentials` to - :class:`google.auth.compute_engine.Credentials`. - - :class:`oauth2client.contrib.appengine.AppAssertionCredentials` to - :class:`google.auth.app_engine.Credentials`. - - Returns: - google.auth.credentials_async.Credentials: The converted credentials. - - Raises: - ValueError: If the credentials could not be converted. - """ - - """ - Not able to inherit from the sync _oauth2client file (calling _oauth2client.convert), - and therefore reproduced here. - """ - - credentials_class = type(credentials) - - try: - return _CLASS_CONVERSION_MAP[credentials_class](credentials) - except KeyError as caught_exc: - new_exc = ValueError( - _oauth2client._CONVERT_ERROR_TMPL.format(credentials_class) - ) - six.raise_from(new_exc, caught_exc) diff --git a/noxfile.py b/noxfile.py index 17213a9d0..7d42e60ac 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,4 +1,4 @@ -# Copyright 2020 Google LLC +# Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_async/test__oauth2client_async.py b/tests_async/test__oauth2client_async.py deleted file mode 100644 index af38dedda..000000000 --- a/tests_async/test__oauth2client_async.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -import sys - -import mock -import oauth2client.client -import oauth2client.contrib.gce -import oauth2client.service_account -import pytest -from six.moves import reload_module - -from google.auth import _oauth2client_async as _oauth2client -from tests import test__oauth2client as test_oauth2client - - -def test__convert_oauth2_credentials(): - old_credentials = oauth2client.client.OAuth2Credentials( - "access_token", - "client_id", - "client_secret", - "refresh_token", - datetime.datetime.min, - "token_uri", - "user_agent", - scopes="one two", - ) - - new_credentials = _oauth2client._convert_oauth2_credentials(old_credentials) - - assert new_credentials.token == old_credentials.access_token - assert new_credentials._refresh_token == old_credentials.refresh_token - assert new_credentials._client_id == old_credentials.client_id - assert new_credentials._client_secret == old_credentials.client_secret - assert new_credentials._token_uri == old_credentials.token_uri - assert new_credentials.scopes == old_credentials.scopes - - -def test__convert_service_account_credentials(): - old_class = oauth2client.service_account.ServiceAccountCredentials - old_credentials = old_class.from_json_keyfile_name( - test_oauth2client.SERVICE_ACCOUNT_JSON_FILE - ) - - new_credentials = _oauth2client._convert_service_account_credentials( - old_credentials - ) - - assert ( - new_credentials.service_account_email == old_credentials.service_account_email - ) - assert new_credentials._signer.key_id == old_credentials._private_key_id - assert new_credentials._token_uri == old_credentials.token_uri - - -def test__convert_service_account_credentials_with_jwt(): - old_class = oauth2client.service_account._JWTAccessCredentials - old_credentials = old_class.from_json_keyfile_name( - test_oauth2client.SERVICE_ACCOUNT_JSON_FILE - ) - - new_credentials = _oauth2client._convert_service_account_credentials( - old_credentials - ) - - assert ( - new_credentials.service_account_email == old_credentials.service_account_email - ) - assert new_credentials._signer.key_id == old_credentials._private_key_id - assert new_credentials._token_uri == old_credentials.token_uri - - -def test__convert_gce_app_assertion_credentials(): - old_credentials = oauth2client.contrib.gce.AppAssertionCredentials( - email="some_email" - ) - - new_credentials = _oauth2client._convert_gce_app_assertion_credentials( - old_credentials - ) - - assert ( - new_credentials.service_account_email == old_credentials.service_account_email - ) - - -@pytest.fixture -def mock_oauth2client_gae_imports(mock_non_existent_module): - mock_non_existent_module("google.appengine.api.app_identity") - mock_non_existent_module("google.appengine.ext.ndb") - mock_non_existent_module("google.appengine.ext.webapp.util") - mock_non_existent_module("webapp2") - - -@mock.patch("google.auth.app_engine.app_identity") -def test__convert_appengine_app_assertion_credentials( - app_identity, mock_oauth2client_gae_imports -): - - import oauth2client.contrib.appengine - - service_account_id = "service_account_id" - old_credentials = oauth2client.contrib.appengine.AppAssertionCredentials( - scope="one two", service_account_id=service_account_id - ) - - new_credentials = _oauth2client._convert_appengine_app_assertion_credentials( - old_credentials - ) - - assert new_credentials.scopes == ["one", "two"] - assert new_credentials._service_account_id == old_credentials.service_account_id - - -class FakeCredentials(object): - pass - - -def test_convert_success(): - convert_function = mock.Mock(spec=["__call__"]) - conversion_map_patch = mock.patch.object( - _oauth2client, "_CLASS_CONVERSION_MAP", {FakeCredentials: convert_function} - ) - credentials = FakeCredentials() - - with conversion_map_patch: - result = _oauth2client.convert(credentials) - - convert_function.assert_called_once_with(credentials) - assert result == convert_function.return_value - - -def test_convert_not_found(): - with pytest.raises(ValueError) as excinfo: - _oauth2client.convert("a string is not a real credentials class") - - assert excinfo.match("Unable to convert") - - -@pytest.fixture -def reset__oauth2client_module(): - """Reloads the _oauth2client module after a test.""" - reload_module(_oauth2client) - - -def test_import_has_app_engine( - mock_oauth2client_gae_imports, reset__oauth2client_module -): - reload_module(_oauth2client) - assert _oauth2client._HAS_APPENGINE - - -def test_import_without_oauth2client(monkeypatch, reset__oauth2client_module): - monkeypatch.setitem(sys.modules, "oauth2client", None) - with pytest.raises(ImportError) as excinfo: - reload_module(_oauth2client) - - assert excinfo.match("oauth2client")