Skip to content

Commit

Permalink
test: Add more tests for from_dict() variants
Browse files Browse the repository at this point in the history
Notably, this exposed the bug tracked in celery#6341 where groups are not
deeply deserialized by `group.from_dict()`.
  • Loading branch information
maybe-sybr authored and jeyrce committed Aug 25, 2021
1 parent 6bb21d2 commit 3423e25
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 1 deletion.
66 changes: 65 additions & 1 deletion t/integration/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from time import sleep

from celery import Task, chain, chord, group, shared_task
from celery import Signature, Task, chain, chord, group, shared_task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger

Expand Down Expand Up @@ -244,3 +244,67 @@ def run(self):
if self.request.retries:
return self.request.retries
raise ValueError()


# The signatures returned by these tasks wouldn't actually run because the
# arguments wouldn't be fulfilled - we never actually delay them so it's fine
@shared_task
def return_nested_signature_chain_chain():
return chain(chain([add.s()]))


@shared_task
def return_nested_signature_chain_group():
return chain(group([add.s()]))


@shared_task
def return_nested_signature_chain_chord():
return chain(chord([add.s()], add.s()))


@shared_task
def return_nested_signature_group_chain():
return group(chain([add.s()]))


@shared_task
def return_nested_signature_group_group():
return group(group([add.s()]))


@shared_task
def return_nested_signature_group_chord():
return group(chord([add.s()], add.s()))


@shared_task
def return_nested_signature_chord_chain():
return chord(chain([add.s()]), add.s())


@shared_task
def return_nested_signature_chord_group():
return chord(group([add.s()]), add.s())


@shared_task
def return_nested_signature_chord_chord():
return chord(chord([add.s()], add.s()), add.s())


@shared_task
def rebuild_signature(sig_dict):
sig_obj = Signature.from_dict(sig_dict)

def _recurse(sig):
if not isinstance(sig, Signature):
raise TypeError("{!r} is not a signature object".format(sig))
# Most canvas types have a `tasks` attribute
if isinstance(sig, (chain, group, chord)):
for task in sig.tasks:
_recurse(task)
# `chord`s also have a `body` attribute
if isinstance(sig, chord):
_recurse(sig.body)
_recurse(sig_obj)
102 changes: 102 additions & 0 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, GroupResult, ResultSet

from . import tasks
from .conftest import get_active_redis_channels, get_redis_connection
from .tasks import (ExpectedException, add, add_chord_to_chord, add_replaced,
add_to_all, add_to_all_to_chord, build_chain_inside_task,
Expand Down Expand Up @@ -1095,3 +1096,104 @@ def test_nested_chord_group_chain_group_tail(self, manager):
)
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [[42, 42]]


class test_signature_serialization:
"""
Confirm nested signatures can be rebuilt after passing through a backend.
These tests are expected to finish and return `None` or raise an exception
in the error case. The exception indicates that some element of a nested
signature object was not properly deserialized from its dictionary
representation, and would explode later on if it were used as a signature.
"""
def test_rebuild_nested_chain_chain(self, manager):
sig = chain(
tasks.return_nested_signature_chain_chain.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

def test_rebuild_nested_chain_group(self, manager):
sig = chain(
tasks.return_nested_signature_chain_group.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

def test_rebuild_nested_chain_chord(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chain(
tasks.return_nested_signature_chain_chord.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

@pytest.mark.xfail(reason="#6341")
def test_rebuild_nested_group_chain(self, manager):
sig = chain(
tasks.return_nested_signature_group_chain.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

@pytest.mark.xfail(reason="#6341")
def test_rebuild_nested_group_group(self, manager):
sig = chain(
tasks.return_nested_signature_group_group.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

@pytest.mark.xfail(reason="#6341")
def test_rebuild_nested_group_chord(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chain(
tasks.return_nested_signature_group_chord.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

def test_rebuild_nested_chord_chain(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chain(
tasks.return_nested_signature_chord_chain.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

def test_rebuild_nested_chord_group(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chain(
tasks.return_nested_signature_chord_group.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)

def test_rebuild_nested_chord_chord(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chain(
tasks.return_nested_signature_chord_chord.s(),
tasks.rebuild_signature.s()
)
sig.delay().get(timeout=TIMEOUT)
139 changes: 139 additions & 0 deletions t/unit/tasks/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,32 @@ def test_from_dict(self):
x['args'] = None
assert group.from_dict(dict(x))

@pytest.mark.xfail(reason="#6341")
def test_from_dict_deep_deserialize(self):
original_group = group([self.add.s(1, 2)] * 42)
serialized_group = json.loads(json.dumps(original_group))
deserialized_group = group.from_dict(serialized_group)
assert all(
isinstance(child_task, Signature)
for child_task in deserialized_group.tasks
)

@pytest.mark.xfail(reason="#6341")
def test_from_dict_deeper_deserialize(self):
inner_group = group([self.add.s(1, 2)] * 42)
outer_group = group([inner_group] * 42)
serialized_group = json.loads(json.dumps(outer_group))
deserialized_group = group.from_dict(serialized_group)
assert all(
isinstance(child_task, Signature)
for child_task in deserialized_group.tasks
)
assert all(
isinstance(grandchild_task, Signature)
for child_task in deserialized_group.tasks
for grandchild_task in child_task.tasks
)

def test_call_empty_group(self):
x = group(app=self.app)
assert not len(x())
Expand Down Expand Up @@ -1059,6 +1085,119 @@ def chord_add():
_state.task_join_will_block = fixture_task_join_will_block
result.task_join_will_block = fixture_task_join_will_block

def test_from_dict(self):
header = self.add.s(1, 2)
original_chord = chord(header=header)
rebuilt_chord = chord.from_dict(dict(original_chord))
assert isinstance(rebuilt_chord, chord)

def test_from_dict_with_body(self):
header = body = self.add.s(1, 2)
original_chord = chord(header=header, body=body)
rebuilt_chord = chord.from_dict(dict(original_chord))
assert isinstance(rebuilt_chord, chord)

def test_from_dict_deep_deserialize(self, subtests):
header = body = self.add.s(1, 2)
original_chord = chord(header=header, body=body)
serialized_chord = json.loads(json.dumps(original_chord))
deserialized_chord = chord.from_dict(serialized_chord)
with subtests.test(msg="Verify chord is deserialized"):
assert isinstance(deserialized_chord, chord)
with subtests.test(msg="Validate chord header tasks is deserialized"):
assert all(
isinstance(child_task, Signature)
for child_task in deserialized_chord.tasks
)
with subtests.test(msg="Verify chord body is deserialized"):
assert isinstance(deserialized_chord.body, Signature)

@pytest.mark.xfail(reason="#6341")
def test_from_dict_deep_deserialize_group(self, subtests):
header = body = group([self.add.s(1, 2)] * 42)
original_chord = chord(header=header, body=body)
serialized_chord = json.loads(json.dumps(original_chord))
deserialized_chord = chord.from_dict(serialized_chord)
with subtests.test(msg="Verify chord is deserialized"):
assert isinstance(deserialized_chord, chord)
# A header which is a group gets unpacked into the chord's `tasks`
with subtests.test(
msg="Validate chord header tasks are deserialized and unpacked"
):
assert all(
isinstance(child_task, Signature)
and not isinstance(child_task, group)
for child_task in deserialized_chord.tasks
)
# A body which is a group remains as it we passed in
with subtests.test(
msg="Validate chord body is deserialized and not unpacked"
):
assert isinstance(deserialized_chord.body, group)
assert all(
isinstance(body_child_task, Signature)
for body_child_task in deserialized_chord.body.tasks
)

@pytest.mark.xfail(reason="#6341")
def test_from_dict_deeper_deserialize_group(self, subtests):
inner_group = group([self.add.s(1, 2)] * 42)
header = body = group([inner_group] * 42)
original_chord = chord(header=header, body=body)
serialized_chord = json.loads(json.dumps(original_chord))
deserialized_chord = chord.from_dict(serialized_chord)
with subtests.test(msg="Verify chord is deserialized"):
assert isinstance(deserialized_chord, chord)
# A header which is a group gets unpacked into the chord's `tasks`
with subtests.test(
msg="Validate chord header tasks are deserialized and unpacked"
):
assert all(
isinstance(child_task, group)
for child_task in deserialized_chord.tasks
)
assert all(
isinstance(grandchild_task, Signature)
for child_task in deserialized_chord.tasks
for grandchild_task in child_task.tasks
)
# A body which is a group remains as it we passed in
with subtests.test(
msg="Validate chord body is deserialized and not unpacked"
):
assert isinstance(deserialized_chord.body, group)
assert all(
isinstance(body_child_task, group)
for body_child_task in deserialized_chord.body.tasks
)
assert all(
isinstance(body_grandchild_task, Signature)
for body_child_task in deserialized_chord.body.tasks
for body_grandchild_task in body_child_task.tasks
)

def test_from_dict_deep_deserialize_chain(self, subtests):
header = body = chain([self.add.s(1, 2)] * 42)
original_chord = chord(header=header, body=body)
serialized_chord = json.loads(json.dumps(original_chord))
deserialized_chord = chord.from_dict(serialized_chord)
with subtests.test(msg="Verify chord is deserialized"):
assert isinstance(deserialized_chord, chord)
# A header which is a chain gets unpacked into the chord's `tasks`
with subtests.test(
msg="Validate chord header tasks are deserialized and unpacked"
):
assert all(
isinstance(child_task, Signature)
and not isinstance(child_task, chain)
for child_task in deserialized_chord.tasks
)
# A body which is a chain gets mutatated into the hidden `_chain` class
with subtests.test(
msg="Validate chord body is deserialized and not unpacked"
):
assert isinstance(deserialized_chord.body, _chain)


class test_maybe_signature(CanvasCase):

Expand Down

0 comments on commit 3423e25

Please sign in to comment.