From f55b8ddc8feef9700eec2cfa048070b45895c3fb Mon Sep 17 00:00:00 2001 From: maybe-sybr <58414429+maybe-sybr@users.noreply.github.com> Date: Tue, 15 Jun 2021 14:24:48 +1000 Subject: [PATCH] fix: Preserve call/errbacks of replaced tasks (#6770) * style: Remove unused var from canvas unit tests * test: Check task ID re-freeze on replacement * refac: Remove duped task ID preservation logic * test: Rework canvas call/errback integration tests This change modifies a bunch of the tests to use unique keys for the `redis_echo` and `redis_count` tasks which are used to validate that callbacks and errbacks are made. We also introduce helper functions for validating that messages/counts are seen to reduce duplicate code. * fix: Preserve call/errbacks of replaced tasks Fixes #6441 * fix: Ensure replacement tasks get the group index This change adds some tests to ensure that when a task is replaced, it runs as expected. This exposed a bug where the group index of a task would be lost when replaced with a chain since chains would not pass their `group_index` option down to the final task when applied. This manifested as the results of chords being mis-ordered on the redis backend since the group index would default to `+inf`. Other backends may have had similar issues. --- celery/app/task.py | 57 ++- celery/canvas.py | 5 +- t/integration/tasks.py | 16 +- t/integration/test_canvas.py | 720 +++++++++++++++++++++++++---------- t/unit/tasks/test_canvas.py | 2 +- t/unit/tasks/test_tasks.py | 47 +-- 6 files changed, 587 insertions(+), 260 deletions(-) diff --git a/celery/app/task.py b/celery/app/task.py index 78025cc513a..1e50e613b58 100644 --- a/celery/app/task.py +++ b/celery/app/task.py @@ -6,9 +6,9 @@ from kombu.exceptions import OperationalError from kombu.utils.uuid import uuid -from celery import current_app, group, states +from celery import current_app, states from celery._state import _task_stack -from celery.canvas import _chain, signature +from celery.canvas import _chain, group, signature from celery.exceptions import (Ignore, ImproperlyConfigured, MaxRetriesExceededError, Reject, Retry) from celery.local import class_property @@ -893,41 +893,40 @@ def replace(self, sig): raise ImproperlyConfigured( "A signature replacing a task must not be part of a chord" ) + if isinstance(sig, _chain) and not getattr(sig, "tasks", True): + raise ImproperlyConfigured("Cannot replace with an empty chain") + # Ensure callbacks or errbacks from the replaced signature are retained if isinstance(sig, group): - sig |= self.app.tasks['celery.accumulate'].s(index=0).set( - link=self.request.callbacks, - link_error=self.request.errbacks, - ) - elif isinstance(sig, _chain): - if not sig.tasks: - raise ImproperlyConfigured( - "Cannot replace with an empty chain" - ) - - if self.request.chain: - # We need to freeze the new signature with the current task's ID to - # ensure that we don't disassociate the new chain from the existing - # task IDs which would break previously constructed results - # objects. - sig.freeze(self.request.id) - if "link" in sig.options: - final_task_links = sig.tasks[-1].options.setdefault("link", []) - final_task_links.extend(maybe_list(sig.options["link"])) - # Construct the new remainder of the task by chaining the signature - # we're being replaced by with signatures constructed from the - # chain elements in the current request. - for t in reversed(self.request.chain): - sig |= signature(t, app=self.app) - + # Groups get uplifted to a chord so that we can link onto the body + sig |= self.app.tasks['celery.accumulate'].s(index=0) + for callback in maybe_list(self.request.callbacks) or []: + sig.link(callback) + for errback in maybe_list(self.request.errbacks) or []: + sig.link_error(errback) + # If the replacement signature is a chain, we need to push callbacks + # down to the final task so they run at the right time even if we + # proceed to link further tasks from the original request below + if isinstance(sig, _chain) and "link" in sig.options: + final_task_links = sig.tasks[-1].options.setdefault("link", []) + final_task_links.extend(maybe_list(sig.options["link"])) + # We need to freeze the replacement signature with the current task's + # ID to ensure that we don't disassociate it from the existing task IDs + # which would break previously constructed results objects. + sig.freeze(self.request.id) + # Ensure the important options from the original signature are retained sig.set( chord=chord, group_id=self.request.group, group_index=self.request.group_index, root_id=self.request.root_id, ) - sig.freeze(self.request.id) - + # If the task being replaced is part of a chain, we need to re-create + # it with the replacement signature - these subsequent tasks will + # retain their original task IDs as well + for t in reversed(self.request.chain or []): + sig |= signature(t, app=self.app) + # Finally, either apply or delay the new signature! if self.request.is_eager: return sig.apply().get() else: diff --git a/celery/canvas.py b/celery/canvas.py index 9b32e832fd0..fb9c9640399 100644 --- a/celery/canvas.py +++ b/celery/canvas.py @@ -642,7 +642,8 @@ def apply_async(self, args=None, kwargs=None, **options): def run(self, args=None, kwargs=None, group_id=None, chord=None, task_id=None, link=None, link_error=None, publisher=None, - producer=None, root_id=None, parent_id=None, app=None, **options): + producer=None, root_id=None, parent_id=None, app=None, + group_index=None, **options): # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. args = args if args else () @@ -656,7 +657,7 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None, tasks, results_from_prepare = self.prepare_steps( args, kwargs, self.tasks, root_id, parent_id, link_error, app, - task_id, group_id, chord, + task_id, group_id, chord, group_index=group_index, ) if results_from_prepare: diff --git a/t/integration/tasks.py b/t/integration/tasks.py index d1b825fcf53..2cbe534fa4c 100644 --- a/t/integration/tasks.py +++ b/t/integration/tasks.py @@ -217,17 +217,17 @@ def retry_once_priority(self, *args, expires=60.0, max_retries=1, @shared_task -def redis_echo(message): +def redis_echo(message, redis_key="redis-echo"): """Task that appends the message to a redis list.""" redis_connection = get_redis_connection() - redis_connection.rpush('redis-echo', message) + redis_connection.rpush(redis_key, message) @shared_task -def redis_count(): - """Task that increments a well-known redis key.""" +def redis_count(redis_key="redis-count"): + """Task that increments a specified or well-known redis key.""" redis_connection = get_redis_connection() - redis_connection.incr('redis-count') + redis_connection.incr(redis_key) @shared_task(bind=True) @@ -295,6 +295,12 @@ def fail(*args): raise ExpectedException(*args) +@shared_task(bind=True) +def fail_replaced(self, *args): + """Replace this task with one which raises ExpectedException.""" + raise self.replace(fail.si(*args)) + + @shared_task def chord_error(*args): return args diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index 02beb8550d4..267fa6e1adb 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -1,3 +1,4 @@ +import collections import re import tempfile import uuid @@ -18,12 +19,12 @@ from .tasks import (ExpectedException, add, add_chord_to_chord, add_replaced, add_to_all, add_to_all_to_chord, build_chain_inside_task, chord_error, collect_ids, delayed_sum, - delayed_sum_with_soft_guard, fail, identity, ids, - print_unicode, raise_error, redis_count, redis_echo, - replace_with_chain, replace_with_chain_which_raises, - replace_with_empty_chain, retry_once, return_exception, - return_priority, second_order_replace1, tsum, - write_to_file_and_return_int) + delayed_sum_with_soft_guard, fail, fail_replaced, + identity, ids, print_unicode, raise_error, redis_count, + redis_echo, replace_with_chain, + replace_with_chain_which_raises, replace_with_empty_chain, + retry_once, return_exception, return_priority, + second_order_replace1, tsum, write_to_file_and_return_int) RETRYABLE_EXCEPTIONS = (OSError, ConnectionError, TimeoutError) @@ -43,6 +44,62 @@ def flaky(fn): return _timeout(_flaky(fn)) +def await_redis_echo(expected_msgs, redis_key="redis-echo", timeout=TIMEOUT): + """ + Helper to wait for a specified or well-known redis key to contain a string. + """ + redis_connection = get_redis_connection() + + if isinstance(expected_msgs, (str, bytes, bytearray)): + expected_msgs = (expected_msgs, ) + expected_msgs = collections.Counter( + e if not isinstance(e, str) else e.encode("utf-8") + for e in expected_msgs + ) + + # This can technically wait for `len(expected_msg_or_msgs) * timeout` :/ + while +expected_msgs: + maybe_key_msg = redis_connection.blpop(redis_key, timeout) + if maybe_key_msg is None: + raise TimeoutError( + "Fetching from {!r} timed out - still awaiting {!r}" + .format(redis_key, dict(+expected_msgs)) + ) + retrieved_key, msg = maybe_key_msg + assert retrieved_key.decode("utf-8") == redis_key + expected_msgs[msg] -= 1 # silently accepts unexpected messages + + # There should be no more elements - block momentarily + assert redis_connection.blpop(redis_key, min(1, timeout)) is None + + +def await_redis_count(expected_count, redis_key="redis-count", timeout=TIMEOUT): + """ + Helper to wait for a specified or well-known redis key to count to a value. + """ + redis_connection = get_redis_connection() + + check_interval = 0.1 + check_max = int(timeout / check_interval) + for i in range(check_max + 1): + maybe_count = redis_connection.get(redis_key) + # It's either `None` or a base-10 integer + if maybe_count is not None: + count = int(maybe_count) + if count == expected_count: + break + elif i >= check_max: + assert count == expected_count + # try again later + sleep(check_interval) + else: + raise TimeoutError("{!r} was never incremented".format(redis_key)) + + # There should be no more increments - block momentarily + sleep(min(1, timeout)) + assert int(redis_connection.get(redis_key)) == expected_count + + class test_link_error: @flaky def test_link_error_eager(self): @@ -476,19 +533,7 @@ def test_chain_replaced_with_a_chain_and_a_callback(self, manager): res = c.delay() assert res.get(timeout=TIMEOUT) == 'Hello world' - - expected_msgs = {link_msg, } - while expected_msgs: - maybe_key_msg = redis_connection.blpop('redis-echo', TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError('redis-echo') - _, msg = maybe_key_msg - msg = msg.decode() - expected_msgs.remove(msg) # KeyError if `msg` is not in here - - # There should be no more elements - block momentarily - assert redis_connection.blpop('redis-echo', min(1, TIMEOUT)) is None - redis_connection.delete('redis-echo') + await_redis_echo({link_msg, }) def test_chain_replaced_with_a_chain_and_an_error_callback(self, manager): if not manager.app.conf.result_backend.startswith('redis'): @@ -507,19 +552,7 @@ def test_chain_replaced_with_a_chain_and_an_error_callback(self, manager): with pytest.raises(ValueError): res.get(timeout=TIMEOUT) - - expected_msgs = {link_msg, } - while expected_msgs: - maybe_key_msg = redis_connection.blpop('redis-echo', TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError('redis-echo') - _, msg = maybe_key_msg - msg = msg.decode() - expected_msgs.remove(msg) # KeyError if `msg` is not in here - - # There should be no more elements - block momentarily - assert redis_connection.blpop('redis-echo', min(1, TIMEOUT)) is None - redis_connection.delete('redis-echo') + await_redis_echo({link_msg, }) def test_chain_with_cb_replaced_with_chain_with_cb(self, manager): if not manager.app.conf.result_backend.startswith('redis'): @@ -539,22 +572,11 @@ def test_chain_with_cb_replaced_with_chain_with_cb(self, manager): res = c.delay() assert res.get(timeout=TIMEOUT) == 'Hello world' + await_redis_echo({link_msg, 'Hello world'}) - expected_msgs = {link_msg, 'Hello world'} - while expected_msgs: - maybe_key_msg = redis_connection.blpop('redis-echo', TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError('redis-echo') - _, msg = maybe_key_msg - msg = msg.decode() - expected_msgs.remove(msg) # KeyError if `msg` is not in here - - # There should be no more elements - block momentarily - assert redis_connection.blpop('redis-echo', min(1, TIMEOUT)) is None - redis_connection.delete('redis-echo') - - @pytest.mark.xfail(reason="#6441") - def test_chain_with_eb_replaced_with_chain_with_eb(self, manager): + def test_chain_with_eb_replaced_with_chain_with_eb( + self, manager, subtests + ): if not manager.app.conf.result_backend.startswith('redis'): raise pytest.skip('Requires redis result backend.') @@ -565,30 +587,18 @@ def test_chain_with_eb_replaced_with_chain_with_eb(self, manager): outer_link_msg = 'External chain errback' c = chain( identity.s('Hello '), - # The replacement chain will pass its args though + # The replacement chain will die and break the encapsulating chain replace_with_chain_which_raises.s(link_msg=inner_link_msg), add.s('world'), ) - c.link_error(redis_echo.s(outer_link_msg)) + c.link_error(redis_echo.si(outer_link_msg)) res = c.delay() - with pytest.raises(ValueError): - res.get(timeout=TIMEOUT) - - expected_msgs = {inner_link_msg, outer_link_msg} - while expected_msgs: - # Shorter timeout here because we expect failure - timeout = min(5, TIMEOUT) - maybe_key_msg = redis_connection.blpop('redis-echo', timeout) - if maybe_key_msg is None: - raise TimeoutError('redis-echo') - _, msg = maybe_key_msg - msg = msg.decode() - expected_msgs.remove(msg) # KeyError if `msg` is not in here - - # There should be no more elements - block momentarily - assert redis_connection.blpop('redis-echo', min(1, TIMEOUT)) is None - redis_connection.delete('redis-echo') + with subtests.test(msg="Chain fails due to a child task dying"): + with pytest.raises(ValueError): + res.get(timeout=TIMEOUT) + with subtests.test(msg="Chain and child task callbacks are called"): + await_redis_echo({inner_link_msg, outer_link_msg}) def test_replace_chain_with_empty_chain(self, manager): r = chain(identity.s(1), replace_with_empty_chain.s()).delay() @@ -597,6 +607,152 @@ def test_replace_chain_with_empty_chain(self, manager): match="Cannot replace with an empty chain"): r.get(timeout=TIMEOUT) + def test_chain_children_with_callbacks(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + callback = redis_count.si(redis_key=redis_key) + + child_task_count = 42 + child_sig = identity.si(1337) + child_sig.link(callback) + chain_sig = chain(child_sig for _ in range(child_task_count)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = chain_sig() + assert res_obj.get(timeout=TIMEOUT) == 1337 + with subtests.test(msg="Chain child task callbacks are called"): + await_redis_count(child_task_count, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_chain_children_with_errbacks(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) + + child_task_count = 42 + child_sig = fail.si() + child_sig.link_error(errback) + chain_sig = chain(child_sig for _ in range(child_task_count)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain fails due to a child task dying"): + res_obj = chain_sig() + with pytest.raises(ExpectedException): + res_obj.get(timeout=TIMEOUT) + with subtests.test(msg="Chain child task errbacks are called"): + # Only the first child task gets a change to run and fail + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_chain_with_callback_child_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + callback = redis_count.si(redis_key=redis_key) + + chain_sig = chain(add_replaced.si(42, 1337), identity.s()) + chain_sig.link(callback) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = chain_sig() + assert res_obj.get(timeout=TIMEOUT) == 42 + 1337 + with subtests.test(msg="Callback is called after chain finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_chain_with_errback_child_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) + + chain_sig = chain(add_replaced.si(42, 1337), fail.s()) + chain_sig.link_error(errback) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = chain_sig() + with pytest.raises(ExpectedException): + res_obj.get(timeout=TIMEOUT) + with subtests.test(msg="Errback is called after chain finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_chain_child_with_callback_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + callback = redis_count.si(redis_key=redis_key) + + child_sig = add_replaced.si(42, 1337) + child_sig.link(callback) + chain_sig = chain(child_sig, identity.s()) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = chain_sig() + assert res_obj.get(timeout=TIMEOUT) == 42 + 1337 + with subtests.test(msg="Callback is called after chain finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_chain_child_with_errback_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) + + child_sig = fail_replaced.si() + child_sig.link_error(errback) + chain_sig = chain(child_sig, identity.si(42)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = chain_sig() + with pytest.raises(ExpectedException): + res_obj.get(timeout=TIMEOUT) + with subtests.test(msg="Errback is called after chain finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_task_replaced_with_chain(self): + orig_sig = replace_with_chain.si(42) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + + def test_chain_child_replaced_with_chain_first(self): + orig_sig = chain(replace_with_chain.si(42), identity.s()) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + + def test_chain_child_replaced_with_chain_middle(self): + orig_sig = chain( + identity.s(42), replace_with_chain.s(), identity.s() + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + + def test_chain_child_replaced_with_chain_last(self): + orig_sig = chain(identity.s(42), replace_with_chain.s()) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + class test_result_set: @@ -818,20 +974,18 @@ def test_callback_called_by_group(self, manager, subtests): redis_connection = get_redis_connection() callback_msg = str(uuid.uuid4()).encode() - callback = redis_echo.si(callback_msg) + redis_key = str(uuid.uuid4()) + callback = redis_echo.si(callback_msg, redis_key=redis_key) group_sig = group(identity.si(42), identity.si(1337)) group_sig.link(callback) - redis_connection.delete("redis-echo") + redis_connection.delete(redis_key) with subtests.test(msg="Group result is returned"): res = group_sig.delay() assert res.get(timeout=TIMEOUT) == [42, 1337] with subtests.test(msg="Callback is called after group is completed"): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Callback was not called in time") - _, msg = maybe_key_msg - assert msg == callback_msg + await_redis_echo({callback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_errback_called_by_group_fail_first(self, manager, subtests): if not manager.app.conf.result_backend.startswith("redis"): @@ -839,21 +993,19 @@ def test_errback_called_by_group_fail_first(self, manager, subtests): redis_connection = get_redis_connection() errback_msg = str(uuid.uuid4()).encode() - errback = redis_echo.si(errback_msg) + redis_key = str(uuid.uuid4()) + errback = redis_echo.si(errback_msg, redis_key=redis_key) group_sig = group(fail.s(), identity.si(42)) group_sig.link_error(errback) - redis_connection.delete("redis-echo") + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from group"): res = group_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after group task fails"): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_errback_called_by_group_fail_last(self, manager, subtests): if not manager.app.conf.result_backend.startswith("redis"): @@ -861,21 +1013,19 @@ def test_errback_called_by_group_fail_last(self, manager, subtests): redis_connection = get_redis_connection() errback_msg = str(uuid.uuid4()).encode() - errback = redis_echo.si(errback_msg) + redis_key = str(uuid.uuid4()) + errback = redis_echo.si(errback_msg, redis_key=redis_key) group_sig = group(identity.si(42), fail.s()) group_sig.link_error(errback) - redis_connection.delete("redis-echo") + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from group"): res = group_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after group task fails"): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_errback_called_by_group_fail_multiple(self, manager, subtests): if not manager.app.conf.result_backend.startswith("redis"): @@ -883,7 +1033,8 @@ def test_errback_called_by_group_fail_multiple(self, manager, subtests): redis_connection = get_redis_connection() expected_errback_count = 42 - errback = redis_count.si() + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) # Include a mix of passing and failing tasks group_sig = group( @@ -891,29 +1042,155 @@ def test_errback_called_by_group_fail_multiple(self, manager, subtests): *(fail.s() for _ in range(expected_errback_count)), ) group_sig.link_error(errback) - redis_connection.delete("redis-count") + + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from group"): res = group_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after group task fails"): - check_interval = 0.1 - check_max = int(TIMEOUT * check_interval) - for i in range(check_max + 1): - maybe_count = redis_connection.get("redis-count") - # It's either `None` or a base-10 integer - count = int(maybe_count or b"0") - if count == expected_errback_count: - # escape and pass - break - elif i < check_max: - # try again later - sleep(check_interval) - else: - # fail - assert count == expected_errback_count - else: - raise TimeoutError("Errbacks were not called in time") + await_redis_count(expected_errback_count, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_children_with_callbacks(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + callback = redis_count.si(redis_key=redis_key) + + child_task_count = 42 + child_sig = identity.si(1337) + child_sig.link(callback) + group_sig = group(child_sig for _ in range(child_task_count)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = group_sig() + assert res_obj.get(timeout=TIMEOUT) == [1337] * child_task_count + with subtests.test(msg="Chain child task callbacks are called"): + await_redis_count(child_task_count, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_children_with_errbacks(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) + + child_task_count = 42 + child_sig = fail.si() + child_sig.link_error(errback) + group_sig = group(child_sig for _ in range(child_task_count)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain fails due to a child task dying"): + res_obj = group_sig() + with pytest.raises(ExpectedException): + res_obj.get(timeout=TIMEOUT) + with subtests.test(msg="Chain child task errbacks are called"): + await_redis_count(child_task_count, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_with_callback_child_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + callback = redis_count.si(redis_key=redis_key) + + group_sig = group(add_replaced.si(42, 1337), identity.si(31337)) + group_sig.link(callback) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = group_sig() + assert res_obj.get(timeout=TIMEOUT) == [42 + 1337, 31337] + with subtests.test(msg="Callback is called after group finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_with_errback_child_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) + + group_sig = group(add_replaced.si(42, 1337), fail.s()) + group_sig.link_error(errback) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = group_sig() + with pytest.raises(ExpectedException): + res_obj.get(timeout=TIMEOUT) + with subtests.test(msg="Errback is called after group finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_child_with_callback_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + callback = redis_count.si(redis_key=redis_key) + + child_sig = add_replaced.si(42, 1337) + child_sig.link(callback) + group_sig = group(child_sig, identity.si(31337)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = group_sig() + assert res_obj.get(timeout=TIMEOUT) == [42 + 1337, 31337] + with subtests.test(msg="Callback is called after group finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_child_with_errback_replaced(self, manager, subtests): + if not manager.app.conf.result_backend.startswith("redis"): + raise pytest.skip("Requires redis result backend.") + redis_connection = get_redis_connection() + + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) + + child_sig = fail_replaced.si() + child_sig.link_error(errback) + group_sig = group(child_sig, identity.si(42)) + + redis_connection.delete(redis_key) + with subtests.test(msg="Chain executes as expected"): + res_obj = group_sig() + with pytest.raises(ExpectedException): + res_obj.get(timeout=TIMEOUT) + with subtests.test(msg="Errback is called after group finishes"): + await_redis_count(1, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_group_child_replaced_with_chain_first(self): + orig_sig = group(replace_with_chain.si(42), identity.s(1337)) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + + def test_group_child_replaced_with_chain_middle(self): + orig_sig = group( + identity.s(42), replace_with_chain.s(1337), identity.s(31337) + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337, 31337] + + def test_group_child_replaced_with_chain_last(self): + orig_sig = group(identity.s(42), replace_with_chain.s(1337)) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] def assert_ids(r, expected_value, expected_root_id, expected_parent_id): @@ -1537,40 +1814,34 @@ def test_errback_called_by_chord_from_simple(self, manager, subtests): redis_connection = get_redis_connection() errback_msg = str(uuid.uuid4()).encode() - errback = redis_echo.si(errback_msg) + redis_key = str(uuid.uuid4()) + errback = redis_echo.si(errback_msg, redis_key=redis_key) child_sig = fail.s() chord_sig = chord((child_sig, ), identity.s()) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from simple header task"): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test( msg="Errback is called after simple header task fails" ): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) chord_sig = chord((identity.si(42), ), child_sig) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from simple body task"): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test( msg="Errback is called after simple body task fails" ): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_error_propagates_to_chord_from_chain(self, manager, subtests): try: @@ -1602,44 +1873,38 @@ def test_errback_called_by_chord_from_chain(self, manager, subtests): redis_connection = get_redis_connection() errback_msg = str(uuid.uuid4()).encode() - errback = redis_echo.si(errback_msg) + redis_key = str(uuid.uuid4()) + errback = redis_echo.si(errback_msg, redis_key=redis_key) child_sig = chain(identity.si(42), fail.s(), identity.si(42)) chord_sig = chord((child_sig, ), identity.s()) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test( msg="Error propagates from header chain which fails before the end" ): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test( msg="Errback is called after header chain which fails before the end" ): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) chord_sig = chord((identity.si(42), ), child_sig) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test( msg="Error propagates from body chain which fails before the end" ): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test( msg="Errback is called after body chain which fails before the end" ): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_error_propagates_to_chord_from_chain_tail(self, manager, subtests): try: @@ -1671,44 +1936,38 @@ def test_errback_called_by_chord_from_chain_tail(self, manager, subtests): redis_connection = get_redis_connection() errback_msg = str(uuid.uuid4()).encode() - errback = redis_echo.si(errback_msg) + redis_key = str(uuid.uuid4()) + errback = redis_echo.si(errback_msg, redis_key=redis_key) child_sig = chain(identity.si(42), fail.s()) chord_sig = chord((child_sig, ), identity.s()) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test( msg="Error propagates from header chain which fails at the end" ): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test( msg="Errback is called after header chain which fails at the end" ): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) chord_sig = chord((identity.si(42), ), child_sig) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test( msg="Error propagates from body chain which fails at the end" ): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test( msg="Errback is called after body chain which fails at the end" ): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_error_propagates_to_chord_from_group(self, manager, subtests): try: @@ -1736,36 +1995,30 @@ def test_errback_called_by_chord_from_group(self, manager, subtests): redis_connection = get_redis_connection() errback_msg = str(uuid.uuid4()).encode() - errback = redis_echo.si(errback_msg) + redis_key = str(uuid.uuid4()) + errback = redis_echo.si(errback_msg, redis_key=redis_key) child_sig = group(identity.si(42), fail.s()) chord_sig = chord((child_sig, ), identity.s()) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from header group"): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after header group fails"): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) chord_sig = chord((identity.si(42), ), child_sig) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from body group"): - redis_connection.delete("redis-echo") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after body group fails"): - maybe_key_msg = redis_connection.blpop("redis-echo", TIMEOUT) - if maybe_key_msg is None: - raise TimeoutError("Errback was not called in time") - _, msg = maybe_key_msg - assert msg == errback_msg + await_redis_echo({errback_msg, }, redis_key=redis_key) + redis_connection.delete(redis_key) def test_errback_called_by_chord_from_group_fail_multiple( self, manager, subtests @@ -1775,7 +2028,8 @@ def test_errback_called_by_chord_from_group_fail_multiple( redis_connection = get_redis_connection() fail_task_count = 42 - errback = redis_count.si() + redis_key = str(uuid.uuid4()) + errback = redis_count.si(redis_key=redis_key) # Include a mix of passing and failing tasks child_sig = group( *(identity.si(42) for _ in range(24)), # arbitrary task count @@ -1784,61 +2038,133 @@ def test_errback_called_by_chord_from_group_fail_multiple( chord_sig = chord((child_sig, ), identity.s()) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from header group"): - redis_connection.delete("redis-count") + redis_connection.delete(redis_key) res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after header group fails"): # NOTE: Here we only expect the errback to be called once since it # is attached to the chord body which is a single task! - expected_errback_count = 1 - check_interval = 0.1 - check_max = int(TIMEOUT * check_interval) - for i in range(check_max + 1): - maybe_count = redis_connection.get("redis-count") - # It's either `None` or a base-10 integer - count = int(maybe_count or b"0") - if count == expected_errback_count: - # escape and pass - break - elif i < check_max: - # try again later - sleep(check_interval) - else: - # fail - assert count == expected_errback_count - else: - raise TimeoutError("Errbacks were not called in time") + await_redis_count(1, redis_key=redis_key) chord_sig = chord((identity.si(42), ), child_sig) chord_sig.link_error(errback) + redis_connection.delete(redis_key) with subtests.test(msg="Error propagates from body group"): - redis_connection.delete("redis-count") res = chord_sig.delay() with pytest.raises(ExpectedException): res.get(timeout=TIMEOUT) with subtests.test(msg="Errback is called after body group fails"): # NOTE: Here we expect the errback to be called once per failing # task in the chord body since it is a group - expected_errback_count = fail_task_count - check_interval = 0.1 - check_max = int(TIMEOUT * check_interval) - for i in range(check_max + 1): - maybe_count = redis_connection.get("redis-count") - # It's either `None` or a base-10 integer - count = int(maybe_count or b"0") - if count == expected_errback_count: - # escape and pass - break - elif i < check_max: - # try again later - sleep(check_interval) - else: - # fail - assert count == expected_errback_count - else: - raise TimeoutError("Errbacks were not called in time") + await_redis_count(fail_task_count, redis_key=redis_key) + redis_connection.delete(redis_key) + + def test_chord_header_task_replaced_with_chain(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + replace_with_chain.si(42), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_header_child_replaced_with_chain_first(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + (replace_with_chain.si(42), identity.s(1337), ), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + + def test_chord_header_child_replaced_with_chain_middle(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + (identity.s(42), replace_with_chain.s(1337), identity.s(31337), ), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337, 31337] + + def test_chord_header_child_replaced_with_chain_last(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + (identity.s(42), replace_with_chain.s(1337), ), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + + def test_chord_body_task_replaced_with_chain(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + replace_with_chain.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_body_chain_child_replaced_with_chain_first(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + chain(replace_with_chain.s(), identity.s(), ), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_body_chain_child_replaced_with_chain_middle(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + chain(identity.s(), replace_with_chain.s(), identity.s(), ), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_body_chain_child_replaced_with_chain_last(self, manager): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + chain(identity.s(), replace_with_chain.s(), ), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] class test_signature_serialization: diff --git a/t/unit/tasks/test_canvas.py b/t/unit/tasks/test_canvas.py index 7527f0aed24..1b6064f0db5 100644 --- a/t/unit/tasks/test_canvas.py +++ b/t/unit/tasks/test_canvas.py @@ -854,7 +854,7 @@ def test_apply_contains_chords_containing_empty_chain(self): # This is an invalid setup because we can't complete a chord header if # there are no actual tasks which will run in it. However, the current # behaviour of an `IndexError` isn't particularly helpful to a user. - res_obj = group_sig.apply_async() + group_sig.apply_async() def test_apply_contains_chords_containing_chain_with_empty_tail(self): ggchild_count = 42 diff --git a/t/unit/tasks/test_tasks.py b/t/unit/tasks/test_tasks.py index ff6f0049c04..fddeae429bf 100644 --- a/t/unit/tasks/test_tasks.py +++ b/t/unit/tasks/test_tasks.py @@ -1,7 +1,7 @@ import socket import tempfile from datetime import datetime, timedelta -from unittest.mock import ANY, MagicMock, Mock, patch +from unittest.mock import ANY, MagicMock, Mock, call, patch, sentinel import pytest from case import ContextMock @@ -992,10 +992,12 @@ def test_send_event(self): retry=True, retry_policy=self.app.conf.task_publish_retry_policy) def test_replace(self): - sig1 = Mock(name='sig1') + sig1 = MagicMock(name='sig1') sig1.options = {} + self.mytask.request.id = sentinel.request_id with pytest.raises(Ignore): self.mytask.replace(sig1) + sig1.freeze.assert_called_once_with(self.mytask.request.id) def test_replace_with_chord(self): sig1 = Mock(name='sig1') @@ -1003,7 +1005,6 @@ def test_replace_with_chord(self): with pytest.raises(ImproperlyConfigured): self.mytask.replace(sig1) - @pytest.mark.usefixtures('depends_on_current_app') def test_replace_callback(self): c = group([self.mytask.s()], app=self.app) c.freeze = Mock(name='freeze') @@ -1011,29 +1012,23 @@ def test_replace_callback(self): self.mytask.request.id = 'id' self.mytask.request.group = 'group' self.mytask.request.root_id = 'root_id' - self.mytask.request.callbacks = 'callbacks' - self.mytask.request.errbacks = 'errbacks' - - class JsonMagicMock(MagicMock): - parent = None - - def __json__(self): - return 'whatever' - - def reprcall(self, *args, **kwargs): - return 'whatever2' - - mocked_signature = JsonMagicMock(name='s') - accumulate_mock = JsonMagicMock(name='accumulate', s=mocked_signature) - self.mytask.app.tasks['celery.accumulate'] = accumulate_mock - - try: - self.mytask.replace(c) - except Ignore: - mocked_signature.return_value.set.assert_called_with( - link='callbacks', - link_error='errbacks', - ) + self.mytask.request.callbacks = callbacks = 'callbacks' + self.mytask.request.errbacks = errbacks = 'errbacks' + + # Replacement groups get uplifted to chords so that we can accumulate + # the results and link call/errbacks - patch the appropriate `chord` + # methods so we can validate this behaviour + with patch( + "celery.canvas.chord.link" + ) as mock_chord_link, patch( + "celery.canvas.chord.link_error" + ) as mock_chord_link_error: + with pytest.raises(Ignore): + self.mytask.replace(c) + # Confirm that the call/errbacks on the original signature are linked + # to the replacement signature as expected + mock_chord_link.assert_called_once_with(callbacks) + mock_chord_link_error.assert_called_once_with(errbacks) def test_replace_group(self): c = group([self.mytask.s()], app=self.app)