Skip to content

Commit

Permalink
fix: Ensure replacement tasks get the group index
Browse files Browse the repository at this point in the history
This change adds some tests to ensure that when a task is replaced, it
runs as expected. This exposed a bug where the group index of a task
would be lost when replaced with a chain since chains would not pass
their `group_index` option down to the final task when applied. This
manifested as the results of chords being mis-ordered on the redis
backend since the group index would default to `+inf`. Other backends
may have had similar issues.
  • Loading branch information
maybe-sybr committed May 23, 2021
1 parent 9a7fdb7 commit fcfbe43
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
5 changes: 3 additions & 2 deletions celery/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,8 @@ def apply_async(self, args=None, kwargs=None, **options):

def run(self, args=None, kwargs=None, group_id=None, chord=None,
task_id=None, link=None, link_error=None, publisher=None,
producer=None, root_id=None, parent_id=None, app=None, **options):
producer=None, root_id=None, parent_id=None, app=None,
group_index=None, **options):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
args = args if args else ()
Expand All @@ -656,7 +657,7 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None,

tasks, results_from_prepare = self.prepare_steps(
args, kwargs, self.tasks, root_id, parent_id, link_error, app,
task_id, group_id, chord,
task_id, group_id, chord, group_index=group_index,
)

if results_from_prepare:
Expand Down
143 changes: 143 additions & 0 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,28 @@ def test_chain_child_with_errback_replaced(self, manager, subtests):
await_redis_count(1, redis_key=redis_key)
redis_connection.delete(redis_key)

def test_task_replaced_with_chain(self):
orig_sig = replace_with_chain.si(42)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == 42

def test_chain_child_replaced_with_chain_first(self):
orig_sig = chain(replace_with_chain.si(42), identity.s())
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == 42

def test_chain_child_replaced_with_chain_middle(self):
orig_sig = chain(
identity.s(42), replace_with_chain.s(), identity.s()
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == 42

def test_chain_child_replaced_with_chain_last(self):
orig_sig = chain(identity.s(42), replace_with_chain.s())
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == 42


class test_result_set:

Expand Down Expand Up @@ -1153,6 +1175,23 @@ def test_group_child_with_errback_replaced(self, manager, subtests):
await_redis_count(1, redis_key=redis_key)
redis_connection.delete(redis_key)

def test_group_child_replaced_with_chain_first(self):
orig_sig = group(replace_with_chain.si(42), identity.s(1337))
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42, 1337]

def test_group_child_replaced_with_chain_middle(self):
orig_sig = group(
identity.s(42), replace_with_chain.s(1337), identity.s(31337)
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42, 1337, 31337]

def test_group_child_replaced_with_chain_last(self):
orig_sig = group(identity.s(42), replace_with_chain.s(1337))
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42, 1337]


def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
root_id, parent_id, value = r.get(timeout=TIMEOUT)
Expand Down Expand Up @@ -2023,6 +2062,110 @@ def test_errback_called_by_chord_from_group_fail_multiple(
await_redis_count(fail_task_count, redis_key=redis_key)
redis_connection.delete(redis_key)

def test_chord_header_task_replaced_with_chain(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
replace_with_chain.si(42),
identity.s(),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42]

def test_chord_header_child_replaced_with_chain_first(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
(replace_with_chain.si(42), identity.s(1337), ),
identity.s(),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42, 1337]

def test_chord_header_child_replaced_with_chain_middle(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
(identity.s(42), replace_with_chain.s(1337), identity.s(31337), ),
identity.s(),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42, 1337, 31337]

def test_chord_header_child_replaced_with_chain_last(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
(identity.s(42), replace_with_chain.s(1337), ),
identity.s(),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42, 1337]

def test_chord_body_task_replaced_with_chain(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
identity.s(42),
replace_with_chain.s(),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42]

def test_chord_body_chain_child_replaced_with_chain_first(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
identity.s(42),
chain(replace_with_chain.s(), identity.s(), ),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42]

def test_chord_body_chain_child_replaced_with_chain_middle(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
identity.s(42),
chain(identity.s(), replace_with_chain.s(), identity.s(), ),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42]

def test_chord_body_chain_child_replaced_with_chain_last(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

orig_sig = chord(
identity.s(42),
chain(identity.s(), replace_with_chain.s(), ),
)
res_obj = orig_sig.delay()
assert res_obj.get(timeout=TIMEOUT) == [42]


class test_signature_serialization:
"""
Expand Down

0 comments on commit fcfbe43

Please sign in to comment.