Skip to content

Commit

Permalink
feat: add timeout parameter to AuthorizedSession.request() (#406)
Browse files Browse the repository at this point in the history
* feat: add timeout to AuthorisedSession.request()

* Add suport for timeout as a tuple to timeout guard

The `request.Request` class also accepts a timeout as a pair
(connect_timeout, read_timeout), and some downstream libraries use
this form.

This commit makes sure that the timeout logic correctly handles
timeouts as a two-tuple.

See also:
https://2.python-requests.org/en/master/user/advanced/#timeouts
  • Loading branch information
plamut authored and tswast committed Dec 12, 2019
1 parent b7f7d7d commit d86d7b8
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 12 deletions.
111 changes: 100 additions & 11 deletions google/auth/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import functools
import logging
import numbers
import time

try:
import requests
Expand Down Expand Up @@ -64,6 +66,50 @@ def data(self):
return self._response.content


class TimeoutGuard(object):
"""A context manager raising an error if the suite execution took too long.
Args:
timeout ([Union[None, float, Tuple[float, float]]]):
The maximum number of seconds a suite can run without the context
manager raising a timeout exception on exit. If passed as a tuple,
the smaller of the values is taken as a timeout. If ``None``, a
timeout error is never raised.
timeout_error_type (Optional[Exception]):
The type of the error to raise on timeout. Defaults to
:class:`requests.exceptions.Timeout`.
"""

def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
self._timeout = timeout
self.remaining_timeout = timeout
self._timeout_error_type = timeout_error_type

def __enter__(self):
self._start = time.time()
return self

def __exit__(self, exc_type, exc_value, traceback):
if exc_value:
return # let the error bubble up automatically

if self._timeout is None:
return # nothing to do, the timeout was not specified

elapsed = time.time() - self._start
deadline_hit = False

if isinstance(self._timeout, numbers.Number):
self.remaining_timeout = self._timeout - elapsed
deadline_hit = self.remaining_timeout <= 0
else:
self.remaining_timeout = tuple(x - elapsed for x in self._timeout)
deadline_hit = min(self.remaining_timeout) <= 0

if deadline_hit:
raise self._timeout_error_type()


class Request(transport.Request):
"""Requests request adapter.
Expand Down Expand Up @@ -193,8 +239,19 @@ def __init__(
# credentials.refresh).
self._auth_request = auth_request

def request(self, method, url, data=None, headers=None, **kwargs):
"""Implementation of Requests' request."""
def request(self, method, url, data=None, headers=None, timeout=None, **kwargs):
"""Implementation of Requests' request.
Args:
timeout (Optional[Union[float, Tuple[float, float]]]): The number
of seconds to wait before raising a ``Timeout`` exception. If
multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If passed as a tuple ``(connect_timeout, read_timeout)``, the
smaller of the values is taken as the total allowed time across
all requests.
"""
# pylint: disable=arguments-differ
# Requests has a ton of arguments to request, but only two
# (method, url) are required. We pass through all of the other
Expand All @@ -208,13 +265,28 @@ def request(self, method, url, data=None, headers=None, **kwargs):
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}

self.credentials.before_request(
self._auth_request, method, url, request_headers
# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)

response = super(AuthorizedSession, self).request(
method, url, data=data, headers=request_headers, **kwargs
)
with TimeoutGuard(timeout) as guard:
self.credentials.before_request(auth_request, method, url, request_headers)
timeout = guard.remaining_timeout

with TimeoutGuard(timeout) as guard:
response = super(AuthorizedSession, self).request(
method,
url,
data=data,
headers=request_headers,
timeout=timeout,
**kwargs
)
timeout = guard.remaining_timeout

# If the response indicated that the credentials needed to be
# refreshed, then refresh the credentials and re-attempt the
Expand All @@ -233,17 +305,34 @@ def request(self, method, url, data=None, headers=None, **kwargs):
self._max_refresh_attempts,
)

auth_request_with_timeout = functools.partial(
self._auth_request, timeout=self._refresh_timeout
if self._refresh_timeout is not None:
if timeout is None:
timeout = self._refresh_timeout
elif isinstance(timeout, numbers.Number):
timeout = min(timeout, self._refresh_timeout)
else:
timeout = tuple(min(x, self._refresh_timeout) for x in timeout)

# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)
self.credentials.refresh(auth_request_with_timeout)

# Recurse. Pass in the original headers, not our modified set.
with TimeoutGuard(timeout) as guard:
self.credentials.refresh(auth_request)
timeout = guard.remaining_timeout

# Recurse. Pass in the original headers, not our modified set, but
# do pass the adjusted timeout (i.e. the remaining time).
return self.request(
method,
url,
data=data,
headers=headers,
timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs
)
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

TEST_DEPENDENCIES = [
"flask",
"freezegun",
"mock",
"oauth2client",
"pytest",
Expand Down
154 changes: 153 additions & 1 deletion tests/transport/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import functools

import freezegun
import mock
import pytest
import requests
import requests.adapters
from six.moves import http_client
Expand All @@ -22,6 +27,12 @@
from tests.transport import compliance


@pytest.fixture
def frozen_time():
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen:
yield frozen


class TestRequestResponse(compliance.RequestResponseTests):
def make_request(self):
return google.auth.transport.requests.Request()
Expand All @@ -34,6 +45,52 @@ def test_timeout(self):
assert http.request.call_args[1]["timeout"] == 5


class TestTimeoutGuard(object):
def make_guard(self, *args, **kwargs):
return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)

def test_tracks_elapsed_time_w_numeric_timeout(self, frozen_time):
with self.make_guard(timeout=10) as guard:
frozen_time.tick(delta=3.8)
assert guard.remaining_timeout == 6.2

def test_tracks_elapsed_time_w_tuple_timeout(self, frozen_time):
with self.make_guard(timeout=(16, 19)) as guard:
frozen_time.tick(delta=3.8)
assert guard.remaining_timeout == (12.2, 15.2)

def test_noop_if_no_timeout(self, frozen_time):
with self.make_guard(timeout=None) as guard:
frozen_time.tick(delta=datetime.timedelta(days=3650))
# NOTE: no timeout error raised, despite years have passed
assert guard.remaining_timeout is None

def test_timeout_error_w_numeric_timeout(self, frozen_time):
with pytest.raises(requests.exceptions.Timeout):
with self.make_guard(timeout=10) as guard:
frozen_time.tick(delta=10.001)
assert guard.remaining_timeout == pytest.approx(-0.001)

def test_timeout_error_w_tuple_timeout(self, frozen_time):
with pytest.raises(requests.exceptions.Timeout):
with self.make_guard(timeout=(11, 10)) as guard:
frozen_time.tick(delta=10.001)
assert guard.remaining_timeout == pytest.approx((0.999, -0.001))

def test_custom_timeout_error_type(self, frozen_time):
class FooError(Exception):
pass

with pytest.raises(FooError):
with self.make_guard(timeout=1, timeout_error_type=FooError):
frozen_time.tick(2)

def test_lets_suite_errors_bubble_up(self, frozen_time):
with pytest.raises(IndexError):
with self.make_guard(timeout=1):
[1, 2, 3][3]


class CredentialsStub(google.auth.credentials.Credentials):
def __init__(self, token="token"):
super(CredentialsStub, self).__init__()
Expand All @@ -49,6 +106,18 @@ def refresh(self, request):
self.token += "1"


class TimeTickCredentialsStub(CredentialsStub):
"""Credentials that spend some (mocked) time when refreshing a token."""

def __init__(self, time_tick, token="token"):
self._time_tick = time_tick
super(TimeTickCredentialsStub, self).__init__(token=token)

def refresh(self, request):
self._time_tick()
super(TimeTickCredentialsStub, self).refresh(requests)


class AdapterStub(requests.adapters.BaseAdapter):
def __init__(self, responses, headers=None):
super(AdapterStub, self).__init__()
Expand All @@ -69,6 +138,18 @@ def close(self): # pragma: NO COVER
return


class TimeTickAdapterStub(AdapterStub):
"""Adapter that spends some (mocked) time when making a request."""

def __init__(self, time_tick, responses, headers=None):
self._time_tick = time_tick
super(TimeTickAdapterStub, self).__init__(responses, headers=headers)

def send(self, request, **kwargs):
self._time_tick()
return super(TimeTickAdapterStub, self).send(request, **kwargs)


def make_response(status=http_client.OK, data=None):
response = requests.Response()
response.status_code = status
Expand Down Expand Up @@ -121,7 +202,9 @@ def test_request_refresh(self):
[make_response(status=http_client.UNAUTHORIZED), final_response]
)

authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=60
)
authed_session.mount(self.TEST_URL, adapter)

result = authed_session.request("GET", self.TEST_URL)
Expand All @@ -136,3 +219,72 @@ def test_request_refresh(self):

assert adapter.requests[1].url == self.TEST_URL
assert adapter.requests[1].headers["authorization"] == "token1"

def test_request_timeout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
authed_session.mount(self.TEST_URL, adapter)

# Because at least two requests have to be made, and each takes one
# second, the total timeout specified will be exceeded.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=1.9)

def test_request_timeout_w_refresh_timeout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=1.9
)
authed_session.mount(self.TEST_URL, adapter)

# The timeout is long, but the short refresh timeout will prevail.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=60)

def test_request_timeout_w_refresh_timeout_and_tuple_timeout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=100
)
authed_session.mount(self.TEST_URL, adapter)

# The shortest timeout will prevail and cause a Timeout error, despite
# other timeouts being quite long.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=(100, 2.9))

0 comments on commit d86d7b8

Please sign in to comment.