Skip to content

Commit

Permalink
Eliminate consecutive chords generated by group | task upgrade (#8663)
Browse files Browse the repository at this point in the history
* chord | task -> attach to body in prepare_steps

* add unit test

* fix: clone original chord before modifying its body

* fix: misuse of task clone

* turning chained chords into a single chord with nested bodies

* remove the for-loop and consider the type of the unrolled group

* replace pop with slice

* add integration tests

* add unit test

* updated tests

---------

Co-authored-by: Wang Han <wanghan@airdoc.com>
  • Loading branch information
hann-wang and Wang Han committed Feb 27, 2024
1 parent ac16f23 commit f8c952d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
9 changes: 9 additions & 0 deletions celery/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,8 @@ def __or__(self, other):
if isinstance(other, group):
# unroll group with one member
other = maybe_unroll_group(other)
if not isinstance(other, group):
return self.__or__(other)
# chain | group() -> chain
tasks = self.unchain_tasks()
if not tasks:
Expand All @@ -981,6 +983,13 @@ def __or__(self, other):
sig = self.clone()
sig.tasks[-1] = chord(
sig.tasks[-1], other, app=self._app)
# In the scenario where the second-to-last item in a chain is a chord,
# it leads to a situation where two consecutive chords are formed.
# In such cases, a further upgrade can be considered.
# This would involve chaining the body of the second-to-last chord with the last chord."
if len(sig.tasks) > 1 and isinstance(sig.tasks[-2], chord):
sig.tasks[-2].body = sig.tasks[-2].body | sig.tasks[-1]
sig.tasks = sig.tasks[:-1]
return sig
elif self.tasks and isinstance(self.tasks[-1], chord):
# CHAIN [last item is chord] -> chain with chord body.
Expand Down
59 changes: 59 additions & 0 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,65 @@ def test_freezing_chain_sets_id_of_last_task(self, manager):
c.freeze(last_task.id)
assert c.id == last_task.id

@pytest.mark.parametrize(
"group_last_task",
[False, True],
)
def test_chaining_upgraded_chords_mixed_canvas_protocol_2(
self, manager, subtests, group_last_task):
""" This test is built to reproduce the github issue https://github.com/celery/celery/issues/8662
The issue describes a canvas where a chain of groups are executed multiple times instead of once.
This test is built to reproduce the issue and to verify that the issue is fixed.
"""
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

if not manager.app.conf.result_backend.startswith('redis'):
raise pytest.skip('Requires redis result backend.')

redis_connection = get_redis_connection()
redis_key = 'echo_chamber'

c = chain(
group([
redis_echo.si('1', redis_key=redis_key),
redis_echo.si('2', redis_key=redis_key)
]),
group([
redis_echo.si('3', redis_key=redis_key),
redis_echo.si('4', redis_key=redis_key),
redis_echo.si('5', redis_key=redis_key)
]),
group([
redis_echo.si('6', redis_key=redis_key),
redis_echo.si('7', redis_key=redis_key),
redis_echo.si('8', redis_key=redis_key),
redis_echo.si('9', redis_key=redis_key)
]),
redis_echo.si('Done', redis_key='Done') if not group_last_task else
group(redis_echo.si('Done', redis_key='Done')),
)

with subtests.test(msg='Run the chain and wait for completion'):
redis_connection.delete(redis_key, 'Done')
c.delay().get(timeout=TIMEOUT)
await_redis_list_message_length(1, redis_key='Done', timeout=10)

with subtests.test(msg='All tasks are executed once'):
actual = [
sig.decode('utf-8')
for sig in redis_connection.lrange(redis_key, 0, -1)
]
expected = [str(i) for i in range(1, 10)]
with subtests.test(msg='All tasks are executed once'):
assert sorted(actual) == sorted(expected)

# Cleanup
redis_connection.delete(redis_key, 'Done')


class test_result_set:

Expand Down
30 changes: 30 additions & 0 deletions t/unit/tasks/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,36 @@ def test_chain_of_chord_upgrade_on_chaining(self):
assert isinstance(new_chain, _chain)
assert isinstance(new_chain.tasks[0].body, chord)

@pytest.mark.parametrize(
"group_last_task",
[False, True],
)
def test_chain_of_chord_upgrade_on_chaining__protocol_2(
self, group_last_task):
c = chain(
group([self.add.s(i, i) for i in range(5)], app=self.app),
group([self.add.s(i, i) for i in range(10, 15)], app=self.app),
group([self.add.s(i, i) for i in range(20, 25)], app=self.app),
self.add.s(30) if not group_last_task else group(self.add.s(30),
app=self.app))
assert isinstance(c, _chain)
assert len(
c.tasks
) == 1, "Consecutive chords should be further upgraded to a single chord."
assert isinstance(c.tasks[0], chord)

def test_chain_of_chord_upgrade_on_chaining__protocol_3(self):
c = chain(
chain([self.add.s(i, i) for i in range(5)]),
group([self.add.s(i, i) for i in range(10, 15)], app=self.app),
chord([signature('header')], signature('body'), app=self.app),
group([self.add.s(i, i) for i in range(20, 25)], app=self.app))
assert isinstance(c, _chain)
assert isinstance(
c.tasks[-1], chord
), "Chord followed by a group should be upgraded to a single chord with chained body."
assert len(c.tasks) == 6

def test_apply_options(self):

class static(Signature):
Expand Down

0 comments on commit f8c952d

Please sign in to comment.