Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Preserve call/errbacks of replaced tasks #6770

Merged
merged 6 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 28 additions & 29 deletions celery/app/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from kombu.exceptions import OperationalError
from kombu.utils.uuid import uuid

from celery import current_app, group, states
from celery import current_app, states
from celery._state import _task_stack
from celery.canvas import _chain, signature
from celery.canvas import _chain, group, signature
from celery.exceptions import (Ignore, ImproperlyConfigured,
MaxRetriesExceededError, Reject, Retry)
from celery.local import class_property
Expand Down Expand Up @@ -893,41 +893,40 @@ def replace(self, sig):
raise ImproperlyConfigured(
"A signature replacing a task must not be part of a chord"
)
if isinstance(sig, _chain) and not getattr(sig, "tasks", True):
raise ImproperlyConfigured("Cannot replace with an empty chain")
maybe-sybr marked this conversation as resolved.
Show resolved Hide resolved

# Ensure callbacks or errbacks from the replaced signature are retained
if isinstance(sig, group):
sig |= self.app.tasks['celery.accumulate'].s(index=0).set(
link=self.request.callbacks,
link_error=self.request.errbacks,
)
elif isinstance(sig, _chain):
if not sig.tasks:
raise ImproperlyConfigured(
"Cannot replace with an empty chain"
)

if self.request.chain:
# We need to freeze the new signature with the current task's ID to
# ensure that we don't disassociate the new chain from the existing
# task IDs which would break previously constructed results
# objects.
sig.freeze(self.request.id)
maybe-sybr marked this conversation as resolved.
Show resolved Hide resolved
if "link" in sig.options:
final_task_links = sig.tasks[-1].options.setdefault("link", [])
final_task_links.extend(maybe_list(sig.options["link"]))
# Construct the new remainder of the task by chaining the signature
# we're being replaced by with signatures constructed from the
# chain elements in the current request.
for t in reversed(self.request.chain):
sig |= signature(t, app=self.app)

# Groups get uplifted to a chord so that we can link onto the body
sig |= self.app.tasks['celery.accumulate'].s(index=0)
for callback in maybe_list(self.request.callbacks) or []:
sig.link(callback)
for errback in maybe_list(self.request.errbacks) or []:
sig.link_error(errback)
# If the replacement signature is a chain, we need to push callbacks
# down to the final task so they run at the right time even if we
# proceed to link further tasks from the original request below
if isinstance(sig, _chain) and "link" in sig.options:
final_task_links = sig.tasks[-1].options.setdefault("link", [])
final_task_links.extend(maybe_list(sig.options["link"]))
# We need to freeze the replacement signature with the current task's
# ID to ensure that we don't disassociate it from the existing task IDs
# which would break previously constructed results objects.
sig.freeze(self.request.id)
# Ensure the important options from the original signature are retained
sig.set(
chord=chord,
group_id=self.request.group,
group_index=self.request.group_index,
root_id=self.request.root_id,
)
sig.freeze(self.request.id)

# If the task being replaced is part of a chain, we need to re-create
# it with the replacement signature - these subsequent tasks will
# retain their original task IDs as well
for t in reversed(self.request.chain or []):
sig |= signature(t, app=self.app)
# Finally, either apply or delay the new signature!
if self.request.is_eager:
return sig.apply().get()
else:
Expand Down
5 changes: 3 additions & 2 deletions celery/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,8 @@ def apply_async(self, args=None, kwargs=None, **options):

def run(self, args=None, kwargs=None, group_id=None, chord=None,
task_id=None, link=None, link_error=None, publisher=None,
producer=None, root_id=None, parent_id=None, app=None, **options):
producer=None, root_id=None, parent_id=None, app=None,
group_index=None, **options):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
args = args if args else ()
Expand All @@ -656,7 +657,7 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None,

tasks, results_from_prepare = self.prepare_steps(
args, kwargs, self.tasks, root_id, parent_id, link_error, app,
task_id, group_id, chord,
task_id, group_id, chord, group_index=group_index,
)

if results_from_prepare:
Expand Down
16 changes: 11 additions & 5 deletions t/integration/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,17 @@ def retry_once_priority(self, *args, expires=60.0, max_retries=1,


@shared_task
def redis_echo(message):
def redis_echo(message, redis_key="redis-echo"):
"""Task that appends the message to a redis list."""
redis_connection = get_redis_connection()
redis_connection.rpush('redis-echo', message)
redis_connection.rpush(redis_key, message)


@shared_task
def redis_count():
"""Task that increments a well-known redis key."""
def redis_count(redis_key="redis-count"):
"""Task that increments a specified or well-known redis key."""
redis_connection = get_redis_connection()
redis_connection.incr('redis-count')
redis_connection.incr(redis_key)


@shared_task(bind=True)
Expand Down Expand Up @@ -295,6 +295,12 @@ def fail(*args):
raise ExpectedException(*args)


@shared_task(bind=True)
def fail_replaced(self, *args):
"""Replace this task with one which raises ExpectedException."""
raise self.replace(fail.si(*args))


@shared_task
def chord_error(*args):
return args
Expand Down