diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index ffb1de27687..cc88050092a 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -3203,3 +3203,48 @@ def on_signature(self, sig, **headers) -> dict: stamped_task.stamp(visitor=CustomStampingVisitor()) stamped_task.apply_async().get() assert assertion_result + + def test_all_tasks_of_canvas_are_stamped(self, manager, subtests): + """ Test that complex canvas are stamped correctly """ + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + @task_received.connect + def task_received_handler(**kwargs): + request = kwargs['request'] + nonlocal assertion_result + + assertion_result = all([ + assertion_result, + all([stamped_header in request.stamps for stamped_header in request.stamped_headers]), + request.stamps['stamp'] == 42 + ]) + + # Using a list because pytest.mark.parametrize does not play well + canvas = [ + add.s(1, 1), + group(add.s(1, 1), add.s(2, 2)), + chain(add.s(1, 1), add.s(2, 2)), + chord([add.s(1, 1), add.s(2, 2)], xsum.s()), + chain(group(add.s(0, 0)), add.s(-1)), + add.s(1, 1) | add.s(10), + group(add.s(1, 1) | add.s(10), add.s(2, 2) | add.s(20)), + chain(add.s(1, 1) | add.s(10), add.s(2) | add.s(20)), + chord([add.s(1, 1) | add.s(10), add.s(2, 2) | add.s(20)], xsum.s()), + chain(chain(add.s(1, 1) | add.s(10), add.s(2) | add.s(20)), add.s(3) | add.s(30)), + chord(group(chain(add.s(1, 1), add.s(2)), chord([add.s(3, 3), add.s(4, 4)], xsum.s())), xsum.s()), + ] + + for sig in canvas: + with subtests.test(msg='Assert all tasks are stamped'): + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': 42} + + stamped_task = sig + stamped_task.stamp(visitor=CustomStampingVisitor()) + assertion_result = True + stamped_task.apply_async().get() + assert assertion_result