Skip to content

Commit

Permalink
fix: Preserve call/errbacks of replaced tasks (celery#6770)
Browse files Browse the repository at this point in the history
* style: Remove unused var from canvas unit tests

* test: Check task ID re-freeze on replacement

* refac: Remove duped task ID preservation logic

* test: Rework canvas call/errback integration tests

This change modifies a bunch of the tests to use unique keys for the
`redis_echo` and `redis_count` tasks which are used to validate that
callbacks and errbacks are made. We also introduce helper functions for
validating that messages/counts are seen to reduce duplicate code.

* fix: Preserve call/errbacks of replaced tasks

Fixes celery#6441

* fix: Ensure replacement tasks get the group index

This change adds some tests to ensure that when a task is replaced, it
runs as expected. This exposed a bug where the group index of a task
would be lost when replaced with a chain since chains would not pass
their `group_index` option down to the final task when applied. This
manifested as the results of chords being mis-ordered on the redis
backend since the group index would default to `+inf`. Other backends
may have had similar issues.
  • Loading branch information
maybe-sybr authored and jeyrce committed Aug 25, 2021
1 parent ee439d8 commit f55b8dd
Show file tree
Hide file tree
Showing 6 changed files with 587 additions and 260 deletions.
57 changes: 28 additions & 29 deletions celery/app/task.py
Expand Up @@ -6,9 +6,9 @@
from kombu.exceptions import OperationalError
from kombu.utils.uuid import uuid

from celery import current_app, group, states
from celery import current_app, states
from celery._state import _task_stack
from celery.canvas import _chain, signature
from celery.canvas import _chain, group, signature
from celery.exceptions import (Ignore, ImproperlyConfigured,
MaxRetriesExceededError, Reject, Retry)
from celery.local import class_property
Expand Down Expand Up @@ -893,41 +893,40 @@ def replace(self, sig):
raise ImproperlyConfigured(
"A signature replacing a task must not be part of a chord"
)
if isinstance(sig, _chain) and not getattr(sig, "tasks", True):
raise ImproperlyConfigured("Cannot replace with an empty chain")

# Ensure callbacks or errbacks from the replaced signature are retained
if isinstance(sig, group):
sig |= self.app.tasks['celery.accumulate'].s(index=0).set(
link=self.request.callbacks,
link_error=self.request.errbacks,
)
elif isinstance(sig, _chain):
if not sig.tasks:
raise ImproperlyConfigured(
"Cannot replace with an empty chain"
)

if self.request.chain:
# We need to freeze the new signature with the current task's ID to
# ensure that we don't disassociate the new chain from the existing
# task IDs which would break previously constructed results
# objects.
sig.freeze(self.request.id)
if "link" in sig.options:
final_task_links = sig.tasks[-1].options.setdefault("link", [])
final_task_links.extend(maybe_list(sig.options["link"]))
# Construct the new remainder of the task by chaining the signature
# we're being replaced by with signatures constructed from the
# chain elements in the current request.
for t in reversed(self.request.chain):
sig |= signature(t, app=self.app)

# Groups get uplifted to a chord so that we can link onto the body
sig |= self.app.tasks['celery.accumulate'].s(index=0)
for callback in maybe_list(self.request.callbacks) or []:
sig.link(callback)
for errback in maybe_list(self.request.errbacks) or []:
sig.link_error(errback)
# If the replacement signature is a chain, we need to push callbacks
# down to the final task so they run at the right time even if we
# proceed to link further tasks from the original request below
if isinstance(sig, _chain) and "link" in sig.options:
final_task_links = sig.tasks[-1].options.setdefault("link", [])
final_task_links.extend(maybe_list(sig.options["link"]))
# We need to freeze the replacement signature with the current task's
# ID to ensure that we don't disassociate it from the existing task IDs
# which would break previously constructed results objects.
sig.freeze(self.request.id)
# Ensure the important options from the original signature are retained
sig.set(
chord=chord,
group_id=self.request.group,
group_index=self.request.group_index,
root_id=self.request.root_id,
)
sig.freeze(self.request.id)

# If the task being replaced is part of a chain, we need to re-create
# it with the replacement signature - these subsequent tasks will
# retain their original task IDs as well
for t in reversed(self.request.chain or []):
sig |= signature(t, app=self.app)
# Finally, either apply or delay the new signature!
if self.request.is_eager:
return sig.apply().get()
else:
Expand Down
5 changes: 3 additions & 2 deletions celery/canvas.py
Expand Up @@ -642,7 +642,8 @@ def apply_async(self, args=None, kwargs=None, **options):

def run(self, args=None, kwargs=None, group_id=None, chord=None,
task_id=None, link=None, link_error=None, publisher=None,
producer=None, root_id=None, parent_id=None, app=None, **options):
producer=None, root_id=None, parent_id=None, app=None,
group_index=None, **options):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
args = args if args else ()
Expand All @@ -656,7 +657,7 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None,

tasks, results_from_prepare = self.prepare_steps(
args, kwargs, self.tasks, root_id, parent_id, link_error, app,
task_id, group_id, chord,
task_id, group_id, chord, group_index=group_index,
)

if results_from_prepare:
Expand Down
16 changes: 11 additions & 5 deletions t/integration/tasks.py
Expand Up @@ -217,17 +217,17 @@ def retry_once_priority(self, *args, expires=60.0, max_retries=1,


@shared_task
def redis_echo(message):
def redis_echo(message, redis_key="redis-echo"):
"""Task that appends the message to a redis list."""
redis_connection = get_redis_connection()
redis_connection.rpush('redis-echo', message)
redis_connection.rpush(redis_key, message)


@shared_task
def redis_count():
"""Task that increments a well-known redis key."""
def redis_count(redis_key="redis-count"):
"""Task that increments a specified or well-known redis key."""
redis_connection = get_redis_connection()
redis_connection.incr('redis-count')
redis_connection.incr(redis_key)


@shared_task(bind=True)
Expand Down Expand Up @@ -295,6 +295,12 @@ def fail(*args):
raise ExpectedException(*args)


@shared_task(bind=True)
def fail_replaced(self, *args):
"""Replace this task with one which raises ExpectedException."""
raise self.replace(fail.si(*args))


@shared_task
def chord_error(*args):
return args
Expand Down

0 comments on commit f55b8dd

Please sign in to comment.