diff --git a/google/auth/credentials_async.py b/google/auth/credentials_async.py new file mode 100644 index 000000000..a131cc44b --- /dev/null +++ b/google/auth/credentials_async.py @@ -0,0 +1,168 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces for credentials.""" + +import abc + +import six + +from google.auth import credentials + + +@six.add_metaclass(abc.ABCMeta) +class Credentials(credentials.Credentials): + """Async inherited credentials class from google.auth.credentials. + The added functionality is the before_request call which requires + async/await syntax. + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + """ + + async def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the authentication header. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (Subclasses may use these arguments to ascertain information about + # the http request.) + + if not self.valid: + self.refresh(request) + self.apply(headers) + + +class AnonymousCredentials(credentials.AnonymousCredentials, Credentials): + """Credentials that do not provide any authentication information. + + These are useful in the case of services that support anonymous access or + local service emulators that do not use credentials. This class inherits + from the sync anonymous credentials file, but is kept if async credentials + is initialized and we would like anonymous credentials. + """ + + +@six.add_metaclass(abc.ABCMeta) +class ReadOnlyScoped(credentials.ReadOnlyScoped): + """Interface for credentials whose scopes can be queried. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = credentials_async.with_scopes(scopes=['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials_async.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +class Scoped(credentials.Scoped): + """Interface for credentials whose scopes can be replaced while copying. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = credentials_async.create_scoped(['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +def with_scopes_if_required(credentials, scopes): + """Creates a copy of the credentials with scopes if scoping is required. + + This helper function is useful when you do not know (or care to know) the + specific type of credentials you are using (such as when you use + :func:`google.auth.default`). This function will call + :meth:`Scoped.with_scopes` if the credentials are scoped credentials and if + the credentials require scoping. Otherwise, it will return the credentials + as-is. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + scope if necessary. + scopes (Sequence[str]): The list of scopes to use. + + Returns: + google.auth.credentials_async.Credentials: Either a new set of scoped + credentials, or the passed in credentials instance if no scoping + was required. + """ + if isinstance(credentials, Scoped) and credentials.requires_scopes: + return credentials.with_scopes(scopes) + else: + return credentials + + +@six.add_metaclass(abc.ABCMeta) +class Signing(credentials.Signing): + """Interface for credentials that can cryptographically sign messages.""" diff --git a/google/auth/transport/aiohttp_requests.py b/google/auth/transport/aiohttp_requests.py new file mode 100644 index 000000000..46816ea5e --- /dev/null +++ b/google/auth/transport/aiohttp_requests.py @@ -0,0 +1,313 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport adapter for Async HTTP (aiohttp).""" + +from __future__ import absolute_import + +import asyncio +import functools + +import aiohttp +import six + +from google.auth import exceptions +from google.auth import transport +from google.auth.transport import requests + +# Timeout can be re-defined depending on async requirement. Currently made 60s more than +# sync timeout. +_DEFAULT_TIMEOUT = 180 # in seconds + + +class _Response(transport.Response): + """ + Requests transport response adapter. + + Args: + response (requests.Response): The raw Requests response. + """ + + def __init__(self, response): + self._response = response + + @property + def status(self): + return self._response.status + + @property + def headers(self): + return self._response.headers + + @property + def data(self): + return self._response.content + + +class Request(transport.Request): + """Requests request adapter. + + This class is used internally for making requests using asyncio transports + in a consistent way. If you use :class:`AuthorizedSession` you do not need + to construct or use this class directly. + + This class can be useful if you want to manually refresh a + :class:`~google.auth.credentials.Credentials` instance:: + + import google.auth.transport.aiohttp_requests + + request = google.auth.transport.aiohttp_requests.Request() + + credentials.refresh(request) + + Args: + session (aiohttp.ClientSession): An instance :class: aiohttp.ClientSession used + to make HTTP requests. If not specified, a session will be created. + + .. automethod:: __call__ + """ + + def __init__(self, session=None): + + self.session = None + + async def __call__( + self, + url, + method="GET", + body=None, + headers=None, + timeout=_DEFAULT_TIMEOUT, + **kwargs + ): + """ + Make an HTTP request using aiohttp. + + Args: + url (str): The URL to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (bytes): The payload / body in HTTP request. + headers (Mapping[str, str]): Request headers. + timeout (Optional[int]): The number of seconds to wait for a + response from the server. If not specified or if None, the + requests default timeout will be used. + kwargs: Additional arguments passed through to the underlying + requests :meth:`~requests.Session.request` method. + + Returns: + google.auth.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + + try: + if self.session is None: # pragma: NO COVER + self.session = aiohttp.ClientSession() # pragma: NO COVER + requests._LOGGER.debug("Making request: %s %s", method, url) + response = await self.session.request( + method, url, data=body, headers=headers, timeout=timeout, **kwargs + ) + return _Response(response) + + except aiohttp.ClientError as caught_exc: + new_exc = exceptions.TransportError(caught_exc) + six.raise_from(new_exc, caught_exc) + + except asyncio.TimeoutError as caught_exc: + new_exc = exceptions.TransportError(caught_exc) + six.raise_from(new_exc, caught_exc) + + +class AuthorizedSession(aiohttp.ClientSession): + """This is an async implementation of the Authorized Session class. We utilize an + aiohttp transport instance, and the interface mirrors the google.auth.transport.requests + Authorized Session class, except for the change in the transport used in the async use case. + + A Requests Session class with credentials. + + This class is used to perform requests to API endpoints that require + authorization:: + + import google.auth.transport.aiohttp_requests + + async with aiohttp_requests.AuthorizedSession(credentials) as authed_session: + response = await authed_session.request( + 'GET', 'https://www.googleapis.com/storage/v1/b') + + The underlying :meth:`request` implementation handles adding the + credentials' headers to the request and refreshing credentials as needed. + + Args: + credentials (google.auth.credentials_async.Credentials): The credentials to + add to the request. + refresh_status_codes (Sequence[int]): Which HTTP status codes indicate + that credentials should be refreshed and the request should be + retried. + max_refresh_attempts (int): The maximum number of times to attempt to + refresh the credentials and retry the request. + refresh_timeout (Optional[int]): The timeout value in seconds for + credential refresh HTTP requests. + auth_request (google.auth.transport.aiohttp_requests.Request): + (Optional) An instance of + :class:`~google.auth.transport.aiohttp_requests.Request` used when + refreshing credentials. If not passed, + an instance of :class:`~google.auth.transport.aiohttp_requests.Request` + is created. + """ + + def __init__( + self, + credentials, + refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, + max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, + refresh_timeout=None, + auth_request=None, + ): + super(AuthorizedSession, self).__init__() + self.credentials = credentials + self._refresh_status_codes = refresh_status_codes + self._max_refresh_attempts = max_refresh_attempts + self._refresh_timeout = refresh_timeout + self._is_mtls = False + self._auth_request = auth_request + self._auth_request_session = None + self._loop = asyncio.get_event_loop() + self._refresh_lock = asyncio.Lock() + + async def request( + self, + method, + url, + data=None, + headers=None, + max_allowed_time=None, + timeout=_DEFAULT_TIMEOUT, + **kwargs + ): + + """Implementation of Authorized Session aiohttp request. + + Args: + method: The http request method used (e.g. GET, PUT, DELETE) + + url: The url at which the http request is sent. + + data, headers: These fields parallel the associated data and headers + fields of a regular http request. Using the aiohttp client session to + send the http request allows us to use this parallel corresponding structure + in our Authorized Session class. + + timeout (Optional[Union[float, Tuple[float, float]]]): + The amount of time in seconds to wait for the server response + with each individual request. + + Can also be passed as a tuple (connect_timeout, read_timeout). + See :meth:`requests.Session.request` documentation for details. + + max_allowed_time (Optional[float]): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time`. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. + """ + + async with aiohttp.ClientSession() as self._auth_request_session: + auth_request = Request(self._auth_request_session) + self._auth_request = auth_request + + # Use a kwarg for this instead of an attribute to maintain + # thread-safety. + _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) + # Make a copy of the headers. They will be modified by the credentials + # and we want to pass the original headers if we recurse. + request_headers = headers.copy() if headers is not None else {} + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + remaining_time = max_allowed_time + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + await self.credentials.before_request( + auth_request, method, url, request_headers + ) + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + response = await super(AuthorizedSession, self).request( + method, + url, + data=data, + headers=request_headers, + timeout=timeout, + **kwargs + ) + + remaining_time = guard.remaining_timeout + + if ( + response.status in self._refresh_status_codes + and _credential_refresh_attempt < self._max_refresh_attempts + ): + + requests._LOGGER.info( + "Refreshing credentials due to a %s response. Attempt %s/%s.", + response.status, + _credential_refresh_attempt + 1, + self._max_refresh_attempts, + ) + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + with requests.TimeoutGuard( + remaining_time, asyncio.TimeoutError + ) as guard: + async with self._refresh_lock: + await self._loop.run_in_executor( + None, self.credentials.refresh, auth_request + ) + + remaining_time = guard.remaining_timeout + + return await self.request( + method, + url, + data=data, + headers=headers, + max_allowed_time=remaining_time, + timeout=timeout, + _credential_refresh_attempt=_credential_refresh_attempt + 1, + **kwargs + ) + + return response diff --git a/noxfile.py b/noxfile.py index c39f27c47..7d42e60ac 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,8 +29,18 @@ "responses", "grpcio", ] + +ASYNC_DEPENDENCIES = ["pytest-asyncio", "aioresponses", "aiohttp"] + BLACK_VERSION = "black==19.3b0" -BLACK_PATHS = ["google", "tests", "noxfile.py", "setup.py", "docs/conf.py"] +BLACK_PATHS = [ + "google", + "tests", + "tests_async", + "noxfile.py", + "setup.py", + "docs/conf.py", +] @nox.session(python="3.7") @@ -44,6 +54,7 @@ def lint(session): "--application-import-names=google,tests,system_tests", "google", "tests", + "tests_async", ) session.run( "python", "setup.py", "check", "--metadata", "--restructuredtext", "--strict" @@ -64,8 +75,23 @@ def blacken(session): session.run("black", *BLACK_PATHS) -@nox.session(python=["2.7", "3.5", "3.6", "3.7", "3.8"]) +@nox.session(python=["3.6", "3.7", "3.8"]) def unit(session): + session.install(*TEST_DEPENDENCIES) + session.install(*(ASYNC_DEPENDENCIES)) + session.install(".") + session.run( + "pytest", + "--cov=google.auth", + "--cov=google.oauth2", + "--cov=tests", + "tests", + "tests_async", + ) + + +@nox.session(python=["2.7", "3.5"]) +def unit_prev_versions(session): session.install(*TEST_DEPENDENCIES) session.install(".") session.run( @@ -76,14 +102,17 @@ def unit(session): @nox.session(python="3.7") def cover(session): session.install(*TEST_DEPENDENCIES) + session.install(*(ASYNC_DEPENDENCIES)) session.install(".") session.run( "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", + "--cov=tests_async", "--cov-report=", "tests", + "tests_async", ) session.run("coverage", "report", "--show-missing", "--fail-under=100") @@ -117,5 +146,10 @@ def pypy(session): session.install(*TEST_DEPENDENCIES) session.install(".") session.run( - "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests" + "pytest", + "--cov=google.auth", + "--cov=google.oauth2", + "--cov=tests", + "tests", + "tests_async", ) diff --git a/tests_async/test_credentials.py b/tests_async/test_credentials.py new file mode 100644 index 000000000..377f9a7e2 --- /dev/null +++ b/tests_async/test_credentials.py @@ -0,0 +1,183 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pytest + +from google.auth import _helpers +from google.auth import credentials_async as credentials + + +class CredentialsImpl(credentials.Credentials): + def refresh(self, request): + self.token = request + + def with_quota_project(self, quota_project_id): + raise NotImplementedError() + + +def test_credentials_constructor(): + credentials = CredentialsImpl() + assert not credentials.token + assert not credentials.expiry + assert not credentials.expired + assert not credentials.valid + + +def test_expired_and_valid(): + credentials = CredentialsImpl() + credentials.token = "token" + + assert credentials.valid + assert not credentials.expired + + # Set the expiration to one second more than now plus the clock skew + # accomodation. These credentials should be valid. + credentials.expiry = ( + datetime.datetime.utcnow() + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1) + ) + + assert credentials.valid + assert not credentials.expired + + # Set the credentials expiration to now. Because of the clock skew + # accomodation, these credentials should report as expired. + credentials.expiry = datetime.datetime.utcnow() + + assert not credentials.valid + assert credentials.expired + + +@pytest.mark.asyncio +async def test_before_request(): + credentials = CredentialsImpl() + request = "token" + headers = {} + + # First call should call refresh, setting the token. + await credentials.before_request(request, "http://example.com", "GET", headers) + assert credentials.valid + assert credentials.token == "token" + assert headers["authorization"] == "Bearer token" + + request = "token2" + headers = {} + + # Second call shouldn't call refresh. + credentials.before_request(request, "http://example.com", "GET", headers) + + assert credentials.valid + assert credentials.token == "token" + + +def test_anonymous_credentials_ctor(): + anon = credentials.AnonymousCredentials() + + assert anon.token is None + assert anon.expiry is None + assert not anon.expired + assert anon.valid + + +def test_anonymous_credentials_refresh(): + anon = credentials.AnonymousCredentials() + + request = object() + with pytest.raises(ValueError): + anon.refresh(request) + + +def test_anonymous_credentials_apply_default(): + anon = credentials.AnonymousCredentials() + headers = {} + anon.apply(headers) + assert headers == {} + with pytest.raises(ValueError): + anon.apply(headers, token="TOKEN") + + +def test_anonymous_credentials_before_request(): + anon = credentials.AnonymousCredentials() + request = object() + method = "GET" + url = "https://example.com/api/endpoint" + headers = {} + anon.before_request(request, method, url, headers) + assert headers == {} + + +def test_anonymous_credentials_with_quota_project(): + with pytest.raises(ValueError): + anon = credentials.AnonymousCredentials() + anon.with_quota_project("project-foo") + + +class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl): + @property + def requires_scopes(self): + return super(ReadOnlyScopedCredentialsImpl, self).requires_scopes + + +def test_readonly_scoped_credentials_constructor(): + credentials = ReadOnlyScopedCredentialsImpl() + assert credentials._scopes is None + + +def test_readonly_scoped_credentials_scopes(): + credentials = ReadOnlyScopedCredentialsImpl() + credentials._scopes = ["one", "two"] + assert credentials.scopes == ["one", "two"] + assert credentials.has_scopes(["one"]) + assert credentials.has_scopes(["two"]) + assert credentials.has_scopes(["one", "two"]) + assert not credentials.has_scopes(["three"]) + + +def test_readonly_scoped_credentials_requires_scopes(): + credentials = ReadOnlyScopedCredentialsImpl() + assert not credentials.requires_scopes + + +class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl): + def __init__(self, scopes=None): + super(RequiresScopedCredentialsImpl, self).__init__() + self._scopes = scopes + + @property + def requires_scopes(self): + return not self.scopes + + def with_scopes(self, scopes): + return RequiresScopedCredentialsImpl(scopes=scopes) + + +def test_create_scoped_if_required_scoped(): + unscoped_credentials = RequiresScopedCredentialsImpl() + scoped_credentials = credentials.with_scopes_if_required( + unscoped_credentials, ["one", "two"] + ) + + assert scoped_credentials is not unscoped_credentials + assert not scoped_credentials.requires_scopes + assert scoped_credentials.has_scopes(["one", "two"]) + + +def test_create_scoped_if_required_not_scopes(): + unscoped_credentials = CredentialsImpl() + scoped_credentials = credentials.with_scopes_if_required( + unscoped_credentials, ["one", "two"] + ) + + assert scoped_credentials is unscoped_credentials diff --git a/tests_async/transport/__init__.py b/tests_async/transport/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py new file mode 100644 index 000000000..0f204bd56 --- /dev/null +++ b/tests_async/transport/async_compliance.py @@ -0,0 +1,136 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import flask +import pytest +from pytest_localserver.http import WSGIServer +from six.moves import http_client + + +from google.auth import exceptions + +# .invalid will never resolve, see https://tools.ietf.org/html/rfc2606 +NXDOMAIN = "test.invalid" + + +class RequestResponseTests(object): + @pytest.fixture(scope="module") + def server(self): + """Provides a test HTTP server. + + The test server is automatically created before + a test and destroyed at the end. The server is serving a test + application that can be used to verify requests. + """ + app = flask.Flask(__name__) + app.debug = True + + # pylint: disable=unused-variable + # (pylint thinks the flask routes are unusued.) + @app.route("/basic") + def index(): + header_value = flask.request.headers.get("x-test-header", "value") + headers = {"X-Test-Header": header_value} + return "Basic Content", http_client.OK, headers + + @app.route("/server_error") + def server_error(): + return "Error", http_client.INTERNAL_SERVER_ERROR + + @app.route("/wait") + def wait(): + time.sleep(3) + return "Waited" + + # pylint: enable=unused-variable + + server = WSGIServer(application=app.wsgi_app) + server.start() + yield server + server.stop() + + @pytest.mark.asyncio + async def test_request_basic(self, server): + request = self.make_request() + response = await request(url=server.url + "/basic", method="GET") + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + # Use 13 as this is the length of the data written into the stream. + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_basic_with_http(self, server): + request = self.make_with_parameter_request() + response = await request(url=server.url + "/basic", method="GET") + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + # Use 13 as this is the length of the data written into the stream. + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_with_timeout_success(self, server): + request = self.make_request() + response = await request(url=server.url + "/basic", method="GET", timeout=2) + + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_with_timeout_failure(self, server): + request = self.make_request() + + with pytest.raises(exceptions.TransportError): + await request(url=server.url + "/wait", method="GET", timeout=1) + + @pytest.mark.asyncio + async def test_request_headers(self, server): + request = self.make_request() + response = await request( + url=server.url + "/basic", + method="GET", + headers={"x-test-header": "hello world"}, + ) + + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "hello world" + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_error(self, server): + request = self.make_request() + + response = await request(url=server.url + "/server_error", method="GET") + assert response.status == http_client.INTERNAL_SERVER_ERROR + data = await response.data.read(5) + assert data == b"Error" + + @pytest.mark.asyncio + async def test_connection_error(self): + request = self.make_request() + + with pytest.raises(exceptions.TransportError): + await request(url="http://{}".format(NXDOMAIN), method="GET") diff --git a/tests_async/transport/test_aiohttp_requests.py b/tests_async/transport/test_aiohttp_requests.py new file mode 100644 index 000000000..4c3d9717b --- /dev/null +++ b/tests_async/transport/test_aiohttp_requests.py @@ -0,0 +1,163 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import aiohttp +from aioresponses import aioresponses +import mock +import pytest +from tests_async.transport import async_compliance + +import google.auth.credentials_async +from google.auth.transport import aiohttp_requests +import google.auth.transport._mtls_helper + + +class TestRequestResponse(async_compliance.RequestResponseTests): + def make_request(self): + return aiohttp_requests.Request() + + def make_with_parameter_request(self): + http = mock.create_autospec(aiohttp.ClientSession, instance=True) + return aiohttp_requests.Request(http) + + def test_timeout(self): + http = mock.create_autospec(aiohttp.ClientSession, instance=True) + request = google.auth.transport.aiohttp_requests.Request(http) + request(url="http://example.com", method="GET", timeout=5) + + +class CredentialsStub(google.auth.credentials_async.Credentials): + def __init__(self, token="token"): + super(CredentialsStub, self).__init__() + self.token = token + + def apply(self, headers, token=None): + headers["authorization"] = self.token + + def refresh(self, request): + self.token += "1" + + +class TestAuthorizedSession(object): + TEST_URL = "http://example.com/" + method = "GET" + + def test_constructor(self): + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( + mock.sentinel.credentials + ) + assert authed_session.credentials == mock.sentinel.credentials + + def test_constructor_with_auth_request(self): + http = mock.create_autospec(aiohttp.ClientSession) + auth_request = google.auth.transport.aiohttp_requests.Request(http) + + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( + mock.sentinel.credentials, auth_request=auth_request + ) + + assert authed_session._auth_request == auth_request + + @pytest.mark.asyncio + async def test_request(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + + mocked.get(self.TEST_URL, status=200, body="test") + session = aiohttp_requests.AuthorizedSession(credentials) + resp = await session.request("GET", "http://example.com/") + + assert resp.status == 200 + assert "test" == await resp.text() + + await session.close() + + @pytest.mark.asyncio + async def test_ctx(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.get("http://test.example.com", payload=dict(foo="bar")) + session = aiohttp_requests.AuthorizedSession(credentials) + resp = await session.request("GET", "http://test.example.com") + data = await resp.json() + + assert dict(foo="bar") == data + + await session.close() + + @pytest.mark.asyncio + async def test_http_headers(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.post( + "http://example.com", + payload=dict(), + headers=dict(connection="keep-alive"), + ) + + session = aiohttp_requests.AuthorizedSession(credentials) + resp = await session.request("POST", "http://example.com") + + assert resp.headers["Connection"] == "keep-alive" + + await session.close() + + @pytest.mark.asyncio + async def test_regexp_example(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.get("http://example.com", status=500) + mocked.get("http://example.com", status=200) + + session1 = aiohttp_requests.AuthorizedSession(credentials) + + resp1 = await session1.request("GET", "http://example.com") + session2 = aiohttp_requests.AuthorizedSession(credentials) + resp2 = await session2.request("GET", "http://example.com") + + assert resp1.status == 500 + assert resp2.status == 200 + + await session1.close() + await session2.close() + + @pytest.mark.asyncio + async def test_request_no_refresh(self): + credentials = mock.Mock(wraps=CredentialsStub()) + with aioresponses() as mocked: + mocked.get("http://example.com", status=200) + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( + credentials + ) + response = await authed_session.request("GET", "http://example.com") + assert response.status == 200 + assert credentials.before_request.called + assert not credentials.refresh.called + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_refresh(self): + credentials = mock.Mock(wraps=CredentialsStub()) + with aioresponses() as mocked: + mocked.get("http://example.com", status=401) + mocked.get("http://example.com", status=200) + authed_session = google.auth.transport.aiohttp_requests.AuthorizedSession( + credentials + ) + response = await authed_session.request("GET", "http://example.com") + assert credentials.refresh.called + assert response.status == 200 + + await authed_session.close()