Skip to content

Commit

Permalink
feat: retryable HTTP status codes can be customized #21 #36
Browse files Browse the repository at this point in the history
Additional changes:
- retry logic has been fixed: `max_retries` now actully denotes the number of
  retries, not the number of total attempts
- the default waiting time between HTTP retries is now customizable
  • Loading branch information
mloesch committed May 3, 2020
1 parent d800757 commit b9dbaf4
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 26 deletions.
73 changes: 50 additions & 23 deletions sickle/app.py
Expand Up @@ -12,9 +12,9 @@
import time

import requests

from sickle.iterator import BaseOAIIterator, OAIItemIterator
from sickle.response import OAIResponse

from .models import (Set, Record, Header, MetadataFormat,
Identify)

Expand Down Expand Up @@ -52,8 +52,15 @@ class Sickle(object):
:type protocol_version: str
:param iterator: The type of the returned iterator
(default: :class:`sickle.iterator.OAIItemIterator`)
:param max_retries: Number of retries if HTTP request fails.
:param max_retries: Number of retry attempts if an HTTP request fails (default: 0 = request only once). Sickle will
use the value from the retry-after header (if present) and will wait the specified number of
seconds between retries.
:type max_retries: int
:param retry_status_codes: HTTP status codes to retry (default will only retry on 503)
:type retry_status_codes: iterable
:param default_retry_after: default number of seconds to wait between retries in case no retry-after header is found
on the response (defaults to 60 seconds)
:type default_retry_after: int
:type protocol_version: str
:param class_mapping: A dictionary that maps OAI verbs to classes representing
OAI items. If not provided,
Expand All @@ -73,9 +80,17 @@ class Sickle(object):
for all available parameters.
"""

def __init__(self, endpoint, http_method='GET', protocol_version='2.0',
iterator=OAIItemIterator, max_retries=5,
class_mapping=None, encoding=None, **request_args):
def __init__(self, endpoint,
http_method='GET',
protocol_version='2.0',
iterator=OAIItemIterator,
max_retries=0,
retry_status_codes=None,
default_retry_after=60,
class_mapping=None,
encoding=None,
**request_args):

self.endpoint = endpoint
if http_method not in ['GET', 'POST']:
raise ValueError("Invalid HTTP method: %s! Must be GET or POST.")
Expand All @@ -90,6 +105,8 @@ def __init__(self, endpoint, http_method='GET', protocol_version='2.0',
raise TypeError(
"Argument 'iterator' must be subclass of %s" % BaseOAIIterator.__name__)
self.max_retries = max_retries
self.retry_status_codes = retry_status_codes or [503]
self.default_retry_after = default_retry_after
self.oai_namespace = OAI_NAMESPACE % self.protocol_version
self.class_mapping = class_mapping or DEFAULT_CLASS_MAP
self.encoding = encoding
Expand All @@ -101,26 +118,24 @@ def harvest(self, **kwargs): # pragma: no cover
:param kwargs: OAI HTTP parameters.
:rtype: :class:`sickle.OAIResponse`
"""
http_response = self._request(kwargs)
for _ in range(self.max_retries):
if self.http_method == 'GET':
http_response = requests.get(self.endpoint, params=kwargs,
**self.request_args)
else:
http_response = requests.post(self.endpoint, data=kwargs,
**self.request_args)
if http_response.status_code == 503:
try:
retry_after = int(http_response.headers.get('retry-after'))
except TypeError:
retry_after = 20
logger.info(
"HTTP 503! Retrying after %d seconds..." % retry_after)
if self._is_error_code(http_response.status_code) \
and http_response.status_code in self.retry_status_codes:
retry_after = self.get_retry_after(http_response)
logger.warning(
"HTTP %d! Retrying after %d seconds..." % (http_response.status_code, retry_after))
time.sleep(retry_after)
else:
http_response.raise_for_status()
if self.encoding:
http_response.encoding = self.encoding
return OAIResponse(http_response, params=kwargs)
http_response = self._request(kwargs)
http_response.raise_for_status()
if self.encoding:
http_response.encoding = self.encoding
return OAIResponse(http_response, params=kwargs)

def _request(self, kwargs):
if self.http_method == 'GET':
return requests.get(self.endpoint, params=kwargs, **self.request_args)
return requests.post(self.endpoint, data=kwargs, **self.request_args)

def ListRecords(self, ignore_deleted=False, **kwargs):
"""Issue a ListRecords request.
Expand Down Expand Up @@ -178,3 +193,15 @@ def ListMetadataFormats(self, **kwargs):
params = kwargs
params.update({'verb': 'ListMetadataFormats'})
return self.iterator(self, params)

def get_retry_after(self, http_response):
if http_response.status_code == 503:
try:
return int(http_response.headers.get('retry-after'))
except TypeError:
return self.default_retry_after
return self.default_retry_after

@staticmethod
def _is_error_code(status_code):
return status_code >= 400
55 changes: 52 additions & 3 deletions sickle/tests/test_sickle.py
Expand Up @@ -10,6 +10,8 @@

from mock import patch, Mock
from nose.tools import raises
from requests import HTTPError

from sickle import Sickle

this_dir, this_filename = os.path.split(__file__)
Expand All @@ -29,7 +31,7 @@ def test_invalid_iterator(self):
Sickle("http://localhost", iterator=None)

def test_pass_request_args(self):
mock_response = Mock(text=u'<xml/>', content='<xml/>')
mock_response = Mock(text=u'<xml/>', content='<xml/>', status_code=200)
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url', timeout=10, proxies=dict(),
Expand All @@ -41,9 +43,56 @@ def test_pass_request_args(self):
auth=('user', 'password'))

def test_override_encoding(self):
mock_response = Mock(text='<xml/>', content='<xml/>')
mock_response = Mock(text='<xml/>', content='<xml/>', status_code=200)
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url', encoding='encoding')
sickle.ListSets()
self.assertEqual(mock_response.encoding, 'encoding')
mock_get.assert_called_once_with('url',
params={'verb': 'ListSets'})

def test_no_retry(self):
mock_response = Mock(status_code=503,
headers={'retry-after': '10'},
raise_for_status=Mock(side_effect=HTTPError))
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url')
try:
sickle.ListRecords()
except HTTPError:
pass
self.assertEqual(1, mock_get.call_count)

def test_retry_on_503(self):
mock_response = Mock(status_code=503,
headers={'retry-after': '10'},
raise_for_status=Mock(side_effect=HTTPError))
mock_get = Mock(return_value=mock_response)
sleep_mock = Mock()
with patch('time.sleep', sleep_mock):
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url', max_retries=3, default_retry_after=0)
try:
sickle.ListRecords()
except HTTPError:
pass
mock_get.assert_called_with('url',
params={'verb': 'ListRecords'})
self.assertEqual(4, mock_get.call_count)
self.assertEqual(3, sleep_mock.call_count)
sleep_mock.assert_called_with(10)

def test_retry_on_custom_code(self):
mock_response = Mock(status_code=500,
raise_for_status=Mock(side_effect=HTTPError))
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url', max_retries=3, default_retry_after=0, retry_status_codes=(503, 500))
try:
sickle.ListRecords()
except HTTPError:
pass
mock_get.assert_called_with('url',
params={'verb': 'ListRecords'})
self.assertEqual(4, mock_get.call_count)

0 comments on commit b9dbaf4

Please sign in to comment.