Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add timeout parameter to AuthorizedSession.request() #406

Merged
merged 2 commits into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 100 additions & 11 deletions google/auth/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import functools
import logging
import numbers
import time

try:
import requests
Expand Down Expand Up @@ -64,6 +66,50 @@ def data(self):
return self._response.content


class TimeoutGuard(object):
"""A context manager raising an error if the suite execution took too long.

Args:
timeout ([Union[None, float, Tuple[float, float]]]):
The maximum number of seconds a suite can run without the context
manager raising a timeout exception on exit. If passed as a tuple,
the smaller of the values is taken as a timeout. If ``None``, a
timeout error is never raised.
timeout_error_type (Optional[Exception]):
The type of the error to raise on timeout. Defaults to
:class:`requests.exceptions.Timeout`.
"""

def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
self._timeout = timeout
self.remaining_timeout = timeout
self._timeout_error_type = timeout_error_type

def __enter__(self):
self._start = time.time()
return self

def __exit__(self, exc_type, exc_value, traceback):
if exc_value:
return # let the error bubble up automatically

if self._timeout is None:
return # nothing to do, the timeout was not specified

elapsed = time.time() - self._start
deadline_hit = False

if isinstance(self._timeout, numbers.Number):
self.remaining_timeout = self._timeout - elapsed
deadline_hit = self.remaining_timeout <= 0
else:
self.remaining_timeout = tuple(x - elapsed for x in self._timeout)
deadline_hit = min(self.remaining_timeout) <= 0

if deadline_hit:
raise self._timeout_error_type()


class Request(transport.Request):
"""Requests request adapter.

Expand Down Expand Up @@ -193,8 +239,19 @@ def __init__(
# credentials.refresh).
self._auth_request = auth_request

def request(self, method, url, data=None, headers=None, **kwargs):
"""Implementation of Requests' request."""
def request(self, method, url, data=None, headers=None, timeout=None, **kwargs):
"""Implementation of Requests' request.

Args:
timeout (Optional[Union[float, Tuple[float, float]]]): The number
of seconds to wait before raising a ``Timeout`` exception. If
multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.

If passed as a tuple ``(connect_timeout, read_timeout)``, the
smaller of the values is taken as the total allowed time across
all requests.
"""
# pylint: disable=arguments-differ
# Requests has a ton of arguments to request, but only two
# (method, url) are required. We pass through all of the other
Expand All @@ -208,13 +265,28 @@ def request(self, method, url, data=None, headers=None, **kwargs):
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}

self.credentials.before_request(
self._auth_request, method, url, request_headers
# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)

response = super(AuthorizedSession, self).request(
method, url, data=data, headers=request_headers, **kwargs
)
with TimeoutGuard(timeout) as guard:
self.credentials.before_request(auth_request, method, url, request_headers)
timeout = guard.remaining_timeout

with TimeoutGuard(timeout) as guard:
response = super(AuthorizedSession, self).request(
method,
url,
data=data,
headers=request_headers,
timeout=timeout,
**kwargs
)
timeout = guard.remaining_timeout

# If the response indicated that the credentials needed to be
# refreshed, then refresh the credentials and re-attempt the
Expand All @@ -233,17 +305,34 @@ def request(self, method, url, data=None, headers=None, **kwargs):
self._max_refresh_attempts,
)

auth_request_with_timeout = functools.partial(
self._auth_request, timeout=self._refresh_timeout
if self._refresh_timeout is not None:
if timeout is None:
timeout = self._refresh_timeout
elif isinstance(timeout, numbers.Number):
timeout = min(timeout, self._refresh_timeout)
else:
timeout = tuple(min(x, self._refresh_timeout) for x in timeout)

# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)
self.credentials.refresh(auth_request_with_timeout)

# Recurse. Pass in the original headers, not our modified set.
with TimeoutGuard(timeout) as guard:
self.credentials.refresh(auth_request)
timeout = guard.remaining_timeout

# Recurse. Pass in the original headers, not our modified set, but
# do pass the adjusted timeout (i.e. the remaining time).
return self.request(
method,
url,
data=data,
headers=headers,
timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs
)
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

TEST_DEPENDENCIES = [
"flask",
"freezegun",
"mock",
"oauth2client",
"pytest",
Expand Down
154 changes: 153 additions & 1 deletion tests/transport/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import functools

import freezegun
import mock
import pytest
import requests
import requests.adapters
from six.moves import http_client
Expand All @@ -22,6 +27,12 @@
from tests.transport import compliance


@pytest.fixture
def frozen_time():
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen:
yield frozen


class TestRequestResponse(compliance.RequestResponseTests):
def make_request(self):
return google.auth.transport.requests.Request()
Expand All @@ -34,6 +45,52 @@ def test_timeout(self):
assert http.request.call_args[1]["timeout"] == 5


class TestTimeoutGuard(object):
def make_guard(self, *args, **kwargs):
return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)

def test_tracks_elapsed_time_w_numeric_timeout(self, frozen_time):
with self.make_guard(timeout=10) as guard:
frozen_time.tick(delta=3.8)
assert guard.remaining_timeout == 6.2

def test_tracks_elapsed_time_w_tuple_timeout(self, frozen_time):
with self.make_guard(timeout=(16, 19)) as guard:
frozen_time.tick(delta=3.8)
assert guard.remaining_timeout == (12.2, 15.2)

def test_noop_if_no_timeout(self, frozen_time):
with self.make_guard(timeout=None) as guard:
frozen_time.tick(delta=datetime.timedelta(days=3650))
# NOTE: no timeout error raised, despite years have passed
assert guard.remaining_timeout is None

def test_timeout_error_w_numeric_timeout(self, frozen_time):
with pytest.raises(requests.exceptions.Timeout):
with self.make_guard(timeout=10) as guard:
frozen_time.tick(delta=10.001)
assert guard.remaining_timeout == pytest.approx(-0.001)

def test_timeout_error_w_tuple_timeout(self, frozen_time):
with pytest.raises(requests.exceptions.Timeout):
with self.make_guard(timeout=(11, 10)) as guard:
frozen_time.tick(delta=10.001)
assert guard.remaining_timeout == pytest.approx((0.999, -0.001))

def test_custom_timeout_error_type(self, frozen_time):
class FooError(Exception):
pass

with pytest.raises(FooError):
with self.make_guard(timeout=1, timeout_error_type=FooError):
frozen_time.tick(2)

def test_lets_suite_errors_bubble_up(self, frozen_time):
with pytest.raises(IndexError):
with self.make_guard(timeout=1):
[1, 2, 3][3]


class CredentialsStub(google.auth.credentials.Credentials):
def __init__(self, token="token"):
super(CredentialsStub, self).__init__()
Expand All @@ -49,6 +106,18 @@ def refresh(self, request):
self.token += "1"


class TimeTickCredentialsStub(CredentialsStub):
"""Credentials that spend some (mocked) time when refreshing a token."""

def __init__(self, time_tick, token="token"):
self._time_tick = time_tick
super(TimeTickCredentialsStub, self).__init__(token=token)

def refresh(self, request):
self._time_tick()
super(TimeTickCredentialsStub, self).refresh(requests)


class AdapterStub(requests.adapters.BaseAdapter):
def __init__(self, responses, headers=None):
super(AdapterStub, self).__init__()
Expand All @@ -69,6 +138,18 @@ def close(self): # pragma: NO COVER
return


class TimeTickAdapterStub(AdapterStub):
"""Adapter that spends some (mocked) time when making a request."""

def __init__(self, time_tick, responses, headers=None):
self._time_tick = time_tick
super(TimeTickAdapterStub, self).__init__(responses, headers=headers)

def send(self, request, **kwargs):
self._time_tick()
return super(TimeTickAdapterStub, self).send(request, **kwargs)


def make_response(status=http_client.OK, data=None):
response = requests.Response()
response.status_code = status
Expand Down Expand Up @@ -121,7 +202,9 @@ def test_request_refresh(self):
[make_response(status=http_client.UNAUTHORIZED), final_response]
)

authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=60
)
authed_session.mount(self.TEST_URL, adapter)

result = authed_session.request("GET", self.TEST_URL)
Expand All @@ -136,3 +219,72 @@ def test_request_refresh(self):

assert adapter.requests[1].url == self.TEST_URL
assert adapter.requests[1].headers["authorization"] == "token1"

def test_request_timeout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
authed_session.mount(self.TEST_URL, adapter)

# Because at least two requests have to be made, and each takes one
# second, the total timeout specified will be exceeded.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=1.9)

def test_request_timeout_w_refresh_timeout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=1.9
)
authed_session.mount(self.TEST_URL, adapter)

# The timeout is long, but the short refresh timeout will prevail.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=60)

def test_request_timeout_w_refresh_timeout_and_tuple_timeout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=100
)
authed_session.mount(self.TEST_URL, adapter)

# The shortest timeout will prevail and cause a Timeout error, despite
# other timeouts being quite long.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=(100, 2.9))