Skip to content

Commit

Permalink
Task class definitions can have retry attributes (#5869)
Browse files Browse the repository at this point in the history
* autoretry_for
* retry_kwargs
* retry_backoff
* retry_backoff_max
* retry_jitter
can now be defined as cls attributes.

All of these can be overriden from the @task decorator

#4684
  • Loading branch information
marcosmoyano authored and auvipy committed Dec 14, 2019
1 parent 984a45b commit a7c74d7
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 5 deletions.
23 changes: 18 additions & 5 deletions celery/app/base.py
Expand Up @@ -460,11 +460,24 @@ def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
self._tasks[task.name] = task
task.bind(self) # connects task to this app

autoretry_for = tuple(options.get('autoretry_for', ()))
retry_kwargs = options.get('retry_kwargs', {})
retry_backoff = int(options.get('retry_backoff', False))
retry_backoff_max = int(options.get('retry_backoff_max', 600))
retry_jitter = options.get('retry_jitter', True)
autoretry_for = tuple(
options.get('autoretry_for',
getattr(task, 'autoretry_for', ()))
)
retry_kwargs = options.get(
'retry_kwargs', getattr(task, 'retry_kwargs', {})
)
retry_backoff = int(
options.get('retry_backoff',
getattr(task, 'retry_backoff', False))
)
retry_backoff_max = int(
options.get('retry_backoff_max',
getattr(task, 'retry_backoff_max', 600))
)
retry_jitter = options.get(
'retry_jitter', getattr(task, 'retry_jitter', True)
)

if autoretry_for and not hasattr(task, '_orig_run'):

Expand Down
171 changes: 171 additions & 0 deletions t/unit/tasks/test_tasks.py
Expand Up @@ -43,6 +43,14 @@ class TaskWithPriority(Task):
priority = 10


class TaskWithRetry(Task):
autoretry_for = (TypeError,)
retry_kwargs = {'max_retries': 5}
retry_backoff = True
retry_backoff_max = 700
retry_jitter = False


class TasksCase:

def setup(self):
Expand Down Expand Up @@ -152,6 +160,81 @@ def autoretry_backoff_jitter_task(self, url):

self.autoretry_backoff_jitter_task = autoretry_backoff_jitter_task

@self.app.task(bind=True, base=TaskWithRetry, shared=False)
def autoretry_for_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.autoretry_for_from_base_task = autoretry_for_from_base_task

@self.app.task(bind=True, base=TaskWithRetry,
autoretry_for=(ZeroDivisionError,), shared=False)
def override_autoretry_for_from_base_task(self, a, b):
self.iterations += 1
return a / b

self.override_autoretry_for = override_autoretry_for_from_base_task

@self.app.task(bind=True, base=TaskWithRetry, shared=False)
def retry_kwargs_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.retry_kwargs_from_base_task = retry_kwargs_from_base_task

@self.app.task(bind=True, base=TaskWithRetry,
retry_kwargs={'max_retries': 2}, shared=False)
def override_retry_kwargs_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.override_retry_kwargs = override_retry_kwargs_from_base_task

@self.app.task(bind=True, base=TaskWithRetry, shared=False)
def retry_backoff_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.retry_backoff_from_base_task = retry_backoff_from_base_task

@self.app.task(bind=True, base=TaskWithRetry,
retry_backoff=False, shared=False)
def override_retry_backoff_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.override_retry_backoff = override_retry_backoff_from_base_task

@self.app.task(bind=True, base=TaskWithRetry, shared=False)
def retry_backoff_max_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.retry_backoff_max_from_base_task = retry_backoff_max_from_base_task

@self.app.task(bind=True, base=TaskWithRetry,
retry_backoff_max=16, shared=False)
def override_retry_backoff_max_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.override_backoff_max = override_retry_backoff_max_from_base_task

@self.app.task(bind=True, base=TaskWithRetry, shared=False)
def retry_backoff_jitter_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.retry_backoff_jitter_from_base = retry_backoff_jitter_from_base_task

@self.app.task(bind=True, base=TaskWithRetry,
retry_jitter=True, shared=False)
def override_backoff_jitter_from_base_task(self, a, b):
self.iterations += 1
return a + b

self.override_backoff_jitter = override_backoff_jitter_from_base_task

@self.app.task(bind=True)
def task_check_request_context(self):
assert self.request.hostname == socket.gethostname()
Expand Down Expand Up @@ -373,6 +456,94 @@ def test_autoretry_backoff_jitter(self, randrange):
]
assert retry_call_countdowns == [0, 1, 3, 7]

def test_autoretry_for_from_base(self):
self.autoretry_for_from_base_task.iterations = 0
self.autoretry_for_from_base_task.apply((1, "a"))
assert self.autoretry_for_from_base_task.iterations == 6

def test_override_autoretry_for_from_base(self):
self.override_autoretry_for.iterations = 0
self.override_autoretry_for.apply((1, 0))
assert self.override_autoretry_for.iterations == 6

def test_retry_kwargs_from_base(self):
self.retry_kwargs_from_base_task.iterations = 0
self.retry_kwargs_from_base_task.apply((1, "a"))
assert self.retry_kwargs_from_base_task.iterations == 6

def test_override_retry_kwargs_from_base(self):
self.override_retry_kwargs.iterations = 0
self.override_retry_kwargs.apply((1, "a"))
assert self.override_retry_kwargs.iterations == 3

def test_retry_backoff_from_base(self):
task = self.retry_backoff_from_base_task
task.iterations = 0
with patch.object(task, 'retry', wraps=task.retry) as fake_retry:
task.apply((1, "a"))

assert task.iterations == 6
retry_call_countdowns = [
call[1]['countdown'] for call in fake_retry.call_args_list
]
assert retry_call_countdowns == [1, 2, 4, 8, 16, 32]

@patch('celery.app.base.get_exponential_backoff_interval')
def test_override_retry_backoff_from_base(self, backoff):
self.override_retry_backoff.iterations = 0
self.override_retry_backoff.apply((1, "a"))
assert self.override_retry_backoff.iterations == 6
assert backoff.call_count == 0

def test_retry_backoff_max_from_base(self):
task = self.retry_backoff_max_from_base_task
task.iterations = 0
with patch.object(task, 'retry', wraps=task.retry) as fake_retry:
task.apply((1, "a"))

assert task.iterations == 6
retry_call_countdowns = [
call[1]['countdown'] for call in fake_retry.call_args_list
]
assert retry_call_countdowns == [1, 2, 4, 8, 16, 32]

def test_override_retry_backoff_max_from_base(self):
task = self.override_backoff_max
task.iterations = 0
with patch.object(task, 'retry', wraps=task.retry) as fake_retry:
task.apply((1, "a"))

assert task.iterations == 6
retry_call_countdowns = [
call[1]['countdown'] for call in fake_retry.call_args_list
]
assert retry_call_countdowns == [1, 2, 4, 8, 16, 16]

def test_retry_backoff_jitter_from_base(self):
task = self.retry_backoff_jitter_from_base
task.iterations = 0
with patch.object(task, 'retry', wraps=task.retry) as fake_retry:
task.apply((1, "a"))

assert task.iterations == 6
retry_call_countdowns = [
call[1]['countdown'] for call in fake_retry.call_args_list
]
assert retry_call_countdowns == [1, 2, 4, 8, 16, 32]

@patch('random.randrange', side_effect=lambda i: i - 2)
def test_override_backoff_jitter_from_base(self, randrange):
task = self.override_backoff_jitter
task.iterations = 0
with patch.object(task, 'retry', wraps=task.retry) as fake_retry:
task.apply((1, "a"))

assert task.iterations == 6
retry_call_countdowns = [
call[1]['countdown'] for call in fake_retry.call_args_list
]
assert retry_call_countdowns == [0, 1, 3, 7, 15, 31]

def test_retry_wrong_eta_when_not_enable_utc(self):
"""Issue #3753"""
self.app.conf.enable_utc = False
Expand Down

0 comments on commit a7c74d7

Please sign in to comment.