From d705a73dd1d8cd75906b7c3e280cd335c5b66b7f Mon Sep 17 00:00:00 2001 From: maybe-sybr <58414429+maybe-sybr@users.noreply.github.com> Date: Mon, 17 May 2021 14:10:29 +1000 Subject: [PATCH] fix: Ensure replacement tasks get the group index This change adds some tests to ensure that when a task is replaced, it runs as expected. This exposed a bug where the group index of a task would be lost when replaced with a chain since chains would not pass their `group_index` option down to the final task when applied. This manifested as the results of chords being mis-ordered on the redis backend since the group index would default to `+inf`. Other backends may have had similar issues. --- celery/canvas.py | 5 +- t/integration/test_canvas.py | 143 +++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/celery/canvas.py b/celery/canvas.py index 9b32e832fd0..fb9c9640399 100644 --- a/celery/canvas.py +++ b/celery/canvas.py @@ -642,7 +642,8 @@ def apply_async(self, args=None, kwargs=None, **options): def run(self, args=None, kwargs=None, group_id=None, chord=None, task_id=None, link=None, link_error=None, publisher=None, - producer=None, root_id=None, parent_id=None, app=None, **options): + producer=None, root_id=None, parent_id=None, app=None, + group_index=None, **options): # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. args = args if args else () @@ -656,7 +657,7 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None, tasks, results_from_prepare = self.prepare_steps( args, kwargs, self.tasks, root_id, parent_id, link_error, app, - task_id, group_id, chord, + task_id, group_id, chord, group_index=group_index, ) if results_from_prepare: diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index 51e6ee2e82d..3d23d387b77 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -731,6 +731,28 @@ def test_chain_child_with_errback_replaced(self, manager, subtests): await_redis_count(1, redis_key=redis_key) redis_connection.delete(redis_key) + def test_task_replaced_with_chain(self): + orig_sig = replace_with_chain.si(42) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + + def test_chain_child_replaced_with_chain_first(self): + orig_sig = chain(replace_with_chain.si(42), identity.s()) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + + def test_chain_child_replaced_with_chain_middle(self): + orig_sig = chain( + identity.s(42), replace_with_chain.s(), identity.s() + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + + def test_chain_child_replaced_with_chain_last(self): + orig_sig = chain(identity.s(42), replace_with_chain.s()) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == 42 + class test_result_set: @@ -1153,6 +1175,23 @@ def test_group_child_with_errback_replaced(self, manager, subtests): await_redis_count(1, redis_key=redis_key) redis_connection.delete(redis_key) + def test_group_child_replaced_with_chain_first(self): + orig_sig = group(replace_with_chain.si(42), identity.s(1337)) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + + def test_group_child_replaced_with_chain_middle(self): + orig_sig = group( + identity.s(42), replace_with_chain.s(1337), identity.s(31337) + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337, 31337] + + def test_group_child_replaced_with_chain_last(self): + orig_sig = group(identity.s(42), replace_with_chain.s(1337)) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + def assert_ids(r, expected_value, expected_root_id, expected_parent_id): root_id, parent_id, value = r.get(timeout=TIMEOUT) @@ -2023,6 +2062,110 @@ def test_errback_called_by_chord_from_group_fail_multiple( await_redis_count(fail_task_count, redis_key=redis_key) redis_connection.delete(redis_key) + def test_chord_header_task_replaced_with_chain(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + replace_with_chain.si(42), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_header_child_replaced_with_chain_first(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + (replace_with_chain.si(42), identity.s(1337), ), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + + def test_chord_header_child_replaced_with_chain_middle(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + (identity.s(42), replace_with_chain.s(1337), identity.s(31337), ), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337, 31337] + + def test_chord_header_child_replaced_with_chain_last(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + (identity.s(42), replace_with_chain.s(1337), ), + identity.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42, 1337] + + def test_chord_body_task_replaced_with_chain(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + replace_with_chain.s(), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_body_chain_child_replaced_with_chain_first(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + chain(replace_with_chain.s(), identity.s(), ), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_body_chain_child_replaced_with_chain_middle(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + chain(identity.s(), replace_with_chain.s(), identity.s(), ), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + + def test_chord_body_chain_child_replaced_with_chain_last(self): + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + orig_sig = chord( + identity.s(42), + chain(identity.s(), replace_with_chain.s(), ), + ) + res_obj = orig_sig.delay() + assert res_obj.get(timeout=TIMEOUT) == [42] + class test_signature_serialization: """