diff --git a/celery/app/task.py b/celery/app/task.py index f8ffaefaffd..fe9c039c0ba 100644 --- a/celery/app/task.py +++ b/celery/app/task.py @@ -8,7 +8,7 @@ from celery import current_app, group, states from celery._state import _task_stack -from celery.canvas import signature +from celery.canvas import _chain, signature from celery.exceptions import (Ignore, ImproperlyConfigured, MaxRetriesExceededError, Reject, Retry) from celery.local import class_property @@ -880,6 +880,11 @@ def replace(self, sig): link=self.request.callbacks, link_error=self.request.errbacks, ) + elif isinstance(sig, _chain): + if not sig.tasks: + raise ImproperlyConfigured( + "Cannot replace with an empty chain" + ) if self.request.chain: # We need to freeze the new signature with the current task's ID to diff --git a/t/integration/tasks.py b/t/integration/tasks.py index 8aa13bc1797..1aaeed32378 100644 --- a/t/integration/tasks.py +++ b/t/integration/tasks.py @@ -100,6 +100,11 @@ def replace_with_chain_which_raises(self, *args, link_msg=None): return self.replace(c) +@shared_task(bind=True) +def replace_with_empty_chain(self, *_): + return self.replace(chain()) + + @shared_task(bind=True) def add_to_all(self, nums, val): """Add the given value to all supplied numbers.""" diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index 4ae027fb10a..97a6b58d112 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -6,7 +6,7 @@ from celery import chain, chord, group, signature from celery.backends.base import BaseKeyValueStoreBackend -from celery.exceptions import TimeoutError +from celery.exceptions import ImproperlyConfigured, TimeoutError from celery.result import AsyncResult, GroupResult, ResultSet from . import tasks @@ -15,9 +15,10 @@ add_to_all, add_to_all_to_chord, build_chain_inside_task, chord_error, collect_ids, delayed_sum, delayed_sum_with_soft_guard, fail, identity, ids, - print_unicode, raise_error, redis_echo, retry_once, - return_exception, return_priority, second_order_replace1, - tsum, replace_with_chain, replace_with_chain_which_raises) + print_unicode, raise_error, redis_echo, + replace_with_chain, replace_with_chain_which_raises, + replace_with_empty_chain, retry_once, return_exception, + return_priority, second_order_replace1, tsum) RETRYABLE_EXCEPTIONS = (OSError, ConnectionError, TimeoutError) @@ -584,6 +585,13 @@ def test_chain_with_eb_replaced_with_chain_with_eb(self, manager): assert redis_connection.blpop('redis-echo', min(1, TIMEOUT)) is None redis_connection.delete('redis-echo') + def test_replace_chain_with_empty_chain(self, manager): + r = chain(identity.s(1), replace_with_empty_chain.s()).delay() + + with pytest.raises(ImproperlyConfigured, + match="Cannot replace with an empty chain"): + r.get(timeout=TIMEOUT) + class test_result_set: