Skip to content

Commit

Permalink
Fixed bug where linking a stamped task did not add the stamp to the l…
Browse files Browse the repository at this point in the history
…ink's options
  • Loading branch information
Nusnus committed Jan 4, 2023
1 parent 24ac092 commit 9990c7a
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
12 changes: 12 additions & 0 deletions celery/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,12 @@ def stamp_links(self, visitor, **headers):
headers.update(visitor_headers or {})
link = maybe_signature(link, app=self.app)
link.stamp(visitor=visitor, **headers)
# Stamping a link to a signature with previous stamps
# may result in missing stamps in the link options, if the linking
# was done AFTER the stamping of the signature
for stamp in link.options['stamped_headers']:
if stamp in self.options and stamp not in link.options:
link.options[stamp] = self.options[stamp]

# Stamp all of the errbacks of this signature
headers = non_visitor_headers.copy()
Expand All @@ -632,6 +638,12 @@ def stamp_links(self, visitor, **headers):
headers.update(visitor_headers or {})
link = maybe_signature(link, app=self.app)
link.stamp(visitor=visitor, **headers)
# Stamping a link to a signature with previous stamps
# may result in missing stamps in the link options, if the linking
# was done AFTER the stamping of the signature
for stamp in link.options['stamped_headers']:
if stamp in self.options and stamp not in link.options:
link.options[stamp] = self.options[stamp]

def _with_list_option(self, key):
"""Gets the value at the given self.options[key] as a list.
Expand Down
75 changes: 75 additions & 0 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3347,3 +3347,78 @@ def task_received_handler(**kwargs):
gid1 = sig.options['task_id']
sleep(1)
assert assertion_result, 'Group stamping is corrupted'

def test_linking_stamped_sig(self, manager):
""" Test that linking a callback after stamping will stamp the callback correctly"""

assertion_result = False

@task_received.connect
def task_received_handler(
sender=None,
request=None,
signal=None,
**kwargs
):
nonlocal assertion_result
link = request._Request__payload[2]['callbacks'][0]
assertion_result = all([
stamped_header in link['options']
for stamped_header in link['options']['stamped_headers']
])

class FixedMonitoringIdStampingVisitor(StampingVisitor):

def __init__(self, msg_id):
self.msg_id = msg_id

def on_signature(self, sig, **headers):
mtask_id = self.msg_id
return {"mtask_id": mtask_id}

link_sig = identity.si('link_sig')
stamped_pass_sig = identity.si('passing sig')
stamped_pass_sig.stamp(visitor=FixedMonitoringIdStampingVisitor(str(uuid.uuid4())))
stamped_pass_sig.link(link_sig)
# This causes the relevant stamping for this test case
# as it will stamp the link via the group stamping internally
stamped_pass_sig.apply_async().get(timeout=2)
assert assertion_result

def test_err_linking_stamped_sig(self, manager):
""" Test that linking an error after stamping will stamp the errlink correctly"""

assertion_result = False

@task_received.connect
def task_received_handler(
sender=None,
request=None,
signal=None,
**kwargs
):
nonlocal assertion_result
link_error = request.errbacks[0]
assertion_result = all([
stamped_header in link_error['options']
for stamped_header in link_error['options']['stamped_headers']
])

class FixedMonitoringIdStampingVisitor(StampingVisitor):

def __init__(self, msg_id):
self.msg_id = msg_id

def on_signature(self, sig, **headers):
mtask_id = self.msg_id
return {"mtask_id": mtask_id}

link_error_sig = identity.si('link_error')
stamped_fail_sig = fail.si()
stamped_fail_sig.stamp(visitor=FixedMonitoringIdStampingVisitor(str(uuid.uuid4())))
stamped_fail_sig.link_error(link_error_sig)
with pytest.raises(ExpectedException):
# This causes the relevant stamping for this test case
# as it will stamp the link via the group stamping internally
stamped_fail_sig.apply_async().get()
assert assertion_result
74 changes: 74 additions & 0 deletions t/unit/tasks/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,80 @@ def on_errback(self, errback, **header) -> dict:
assert headers['on_errback'] is True
assert headers['header'] == 'value'

@pytest.mark.usefixtures('depends_on_current_app')
def test_callback_stamping_link_after_stamp(self, subtests):
self.app.conf.task_always_eager = True
self.app.conf.task_store_eager_result = True
self.app.conf.result_extended = True

class CustomStampingVisitor(StampingVisitor):
def on_signature(self, sig, **headers) -> dict:
return {'header': 'value'}

def on_callback(self, callback, **header) -> dict:
return {'on_callback': True}

def on_errback(self, errback, **header) -> dict:
return {'on_errback': True}

sig_1 = self.add.s(0, 1)
sig_1_res = sig_1.freeze()
group_sig = group([self.add.s(3), self.add.s(4)])
group_sig_res = group_sig.freeze()
chord_sig = chord([self.xsum.s(), self.xsum.s()], self.xsum.s())
chord_sig_res = chord_sig.freeze()
sig_2 = self.add.s(2)
sig_2_res = sig_2.freeze()
chain_sig = chain(
sig_1, # --> 1
group_sig, # --> [1+3, 1+4] --> [4, 5]
chord_sig, # --> [4+5, 4+5] --> [9, 9] --> 9+9 --> 18
sig_2 # --> 18 + 2 --> 20
)
callback = signature('callback_task')
errback = signature('errback_task')
chain_sig.stamp(visitor=CustomStampingVisitor())
chain_sig.link(callback)
chain_sig.link_error(errback)
chain_sig_res = chain_sig.apply_async()
chain_sig_res.get()

with subtests.test("Confirm the chain was executed correctly", result=20):
# Before we run our assersions, let's confirm the base functionality of the chain is working
# as expected including the links stamping.
assert chain_sig_res.result == 20

with subtests.test("sig_1 is stamped with custom visitor", stamped_headers=["header", "groups"]):
assert sorted(sig_1_res._get_task_meta()["stamped_headers"]) == sorted(["header", "groups"])

with subtests.test("group_sig is stamped with custom visitor", stamped_headers=["header", "groups"]):
for result in group_sig_res.results:
assert sorted(result._get_task_meta()["stamped_headers"]) == sorted(["header", "groups"])

with subtests.test("chord_sig is stamped with custom visitor", stamped_headers=["header", "groups"]):
assert sorted(chord_sig_res._get_task_meta()["stamped_headers"]) == sorted(["header", "groups"])

with subtests.test("sig_2 is stamped with custom visitor", stamped_headers=["header", "groups"]):
assert sorted(sig_2_res._get_task_meta()["stamped_headers"]) == sorted(["header", "groups"])

with subtests.test("callback is stamped with custom visitor",
stamped_headers=["header", "groups, on_callback"]):
callback_link = chain_sig.options['link'][0]
headers = callback_link.options
stamped_headers = headers['stamped_headers']
assert 'on_callback' not in stamped_headers, "Linking after stamping should not stamp the callback"
assert sorted(stamped_headers) == sorted(["header", "groups"])
assert headers['header'] == 'value'

with subtests.test("errback is stamped with custom visitor",
stamped_headers=["header", "groups, on_errback"]):
errback_link = chain_sig.options['link_error'][0]
headers = errback_link.options
stamped_headers = headers['stamped_headers']
assert 'on_callback' not in stamped_headers, "Linking after stamping should not stamp the errback"
assert sorted(stamped_headers) == sorted(["header", "groups"])
assert headers['header'] == 'value'

@pytest.mark.usefixtures('depends_on_current_app')
def test_callback_stamping_on_replace(self, subtests):
class CustomStampingVisitor(StampingVisitor):
Expand Down

0 comments on commit 9990c7a

Please sign in to comment.