Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
20d4bf5
fix(signals): support Celery protocol v2 in create_user_task
asajjad2 May 21, 2025
5ca72bb
fix: resolve pylint warnings for variable shadowing
asajjad2 May 30, 2025
d9a4c4a
test: add direct tests for proto2_to_proto1, extract_proto2_headers, …
asajjad2 Jun 24, 2025
6893eb5
test: use pytest monkeypatch fixture instead of mocker
asajjad2 Jun 24, 2025
d4ae02a
test: fix proto2_to_proto1 test to match stringified list output in c…
asajjad2 Jun 24, 2025
a19e78b
test: fix proto2_to_proto1 test to check absence of 'chain' key
asajjad2 Jun 24, 2025
4d49fbf
docs: add docstring to TestUtils to satisfy pylint
asajjad2 Jun 24, 2025
446011d
test: add test to cover Celery protocol v2 path in create_user_task
asajjad2 Jun 24, 2025
850b8d2
test: use unittest.mock.patch for protocol v2 test in Django TestCase
asajjad2 Jun 24, 2025
a7104b5
test: directly test protocol v2 path by calling create_user_task with…
asajjad2 Jun 24, 2025
79cbf14
test: remove unneeded protocol v2 override setting
asajjad2 Jun 24, 2025
6cc71f9
test: fix protocol v2 test to avoid duplicate user_id arg
asajjad2 Jun 24, 2025
7b1efac
test: fix protocol v2 test to avoid duplicate 'argument' arg
asajjad2 Jun 24, 2025
2c0012b
test: set and restore celery_app.conf.task_protocol directly
asajjad2 Jun 24, 2025
2620e1f
chore: remove unused patch import after refactor
asajjad2 Jun 24, 2025
24a3da3
chore: reorder imports in test_signals.py to satisfy isort
asajjad2 Jun 24, 2025
9b0bf42
chore: fix UserTask import order in test_signals.py for isort
asajjad2 Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from user_tasks import user_task_stopped
from user_tasks.models import UserTaskStatus
from user_tasks.signals import start_user_task
from user_tasks.signals import celery_app, create_user_task, start_user_task
from user_tasks.tasks import UserTask
from user_tasks.utils import extract_proto2_embed, extract_proto2_headers, proto2_to_proto1

User = auth.get_user_model()

Expand Down Expand Up @@ -189,6 +190,31 @@ def test_non_user_task_success(self):
statuses = UserTaskStatus.objects.all()
assert not statuses

def test_create_user_task_protocol_v2(self):
"""The create_user_task signal handler should work with Celery protocol version 2."""

original_protocol = getattr(celery_app.conf, 'task_protocol', 1)
celery_app.conf.task_protocol = 2
try:
body = (
[self.user.id, 'Argument'],
{},
{'callbacks': [], 'errbacks': [], 'task_chain': None, 'chord': None}
)
headers = {
'task_id': 'tid', 'retries': 0, 'eta': None, 'expires': None,
'group': None, 'timelimit': [None, None], 'task': 'test_signals.sample_task'
}
create_user_task(sender='test_signals.sample_task', body=body, headers=headers)
statuses = UserTaskStatus.objects.all()
assert len(statuses) == 1
status = statuses[0]
assert status.task_class == 'test_signals.sample_task'
assert status.user_id == self.user.id
assert status.name == 'SampleTask: Argument'
finally:
celery_app.conf.task_protocol = original_protocol

def _create_user_task(self, eager):
"""Create a task based on UserTaskMixin and verify some assertions about its corresponding status."""
result = sample_task.delay(self.user.id, 'Argument')
Expand Down Expand Up @@ -530,3 +556,68 @@ def test_connections_not_closed_when_we_cant_get_a_connection(self, mock_close_o
with mock.patch('user_tasks.signals.transaction.get_connection', side_effect=Exception):
start_user_task(sender=SampleTask)
assert mock_close_old_connections.called is False


class TestUtils:
"""
Unit tests for utility functions in user_tasks/utils.py.
"""

def test_extract_proto2_headers(self):
headers = extract_proto2_headers(
task_id='abc123', retries=2, eta='2025-05-30T12:00:00',
expires=None, group='group1', timelimit=[10, 20],
task='my_task', extra='ignored')
assert headers == {
'id': 'abc123',
'task': 'my_task',
'retries': 2,
'eta': '2025-05-30T12:00:00',
'expires': None,
'utc': True,
'taskset': 'group1',
'timelimit': [10, 20],
}

def test_extract_proto2_embed(self):
embed = extract_proto2_embed(
callbacks=['cb'], errbacks=['eb'], task_chain=['a', 'b'],
chord='chord1', extra='ignored')
assert embed == {
'callbacks': ['cb'],
'errbacks': ['eb'],
'chain': ['a', 'b'],
'chord': 'chord1',
}
embed = extract_proto2_embed()
assert embed == {
'callbacks': [],
'errbacks': [],
'chain': None,
'chord': None
}

def test_proto2_to_proto1(self, monkeypatch):
monkeypatch.setattr(
'user_tasks.utils.chain',
lambda x: f'chain({x})'
)
body = (
[1, 2],
{'foo': 'bar'},
{'callbacks': ['cb'], 'errbacks': ['eb'],
'task_chain': ['a'], 'chord': 'ch'}
)
headers = {
'task_id': 'tid', 'retries': 1, 'eta': 'eta', 'expires': 'exp',
'group': 'grp', 'timelimit': [1, 2], 'task': 't',
'extra': 'ignored'
}
result = proto2_to_proto1(body, headers)
assert result['id'] == 'tid'
assert result['args'] == [1, 2]
assert result['kwargs'] == {'foo': 'bar'}
assert result['callbacks'] == ['cb', "chain(['a'])"]
assert result['errbacks'] == ['eb']
assert 'chain' not in result
assert result['chord'] == 'ch'
48 changes: 29 additions & 19 deletions user_tasks/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
from uuid import uuid4

from celery import current_app as celery_app
from celery.signals import before_task_publish, task_failure, task_prerun, task_retry, task_success

from django.contrib.auth import get_user_model
Expand All @@ -16,40 +17,49 @@
from .exceptions import TaskCanceledException
from .models import UserTaskStatus
from .tasks import UserTaskMixin
from .utils import proto2_to_proto1

LOGGER = logging.getLogger(__name__)


@before_task_publish.connect
def create_user_task(sender=None, body=None, **kwargs):
def create_user_task(sender=None, body=None, headers=None, **kwargs):
"""
Create a :py:class:`UserTaskStatus` record for each :py:class:`UserTaskMixin`.

Also creates a :py:class:`UserTaskStatus` for each chain, chord, or group containing
the new :py:class:`UserTaskMixin`.

Supports Celery protocol v1 and v2.
"""
try:
task_class = import_string(sender)
except ImportError:
return
if issubclass(task_class.__class__, UserTaskMixin):
arguments_dict = task_class.arguments_as_dict(*body['args'], **body['kwargs'])
user_id = _get_user_id(arguments_dict)
task_id = body['id']
if body.get('callbacks', []):
_create_chain_entry(user_id, task_id, task_class, body['args'], body['kwargs'], body['callbacks'])
return
if body.get('chord', None):
_create_chord_entry(task_id, task_class, body, user_id)
return
parent = _get_or_create_group_parent(body, user_id)
name = task_class.generate_name(arguments_dict)
total_steps = task_class.calculate_total_steps(arguments_dict)
UserTaskStatus.objects.get_or_create(
task_id=task_id, defaults={'user_id': user_id, 'parent': parent, 'name': name, 'task_class': sender,
'total_steps': total_steps})
if parent:
parent.increment_total_steps(total_steps)

if celery_app.conf.task_protocol == 2 and isinstance(body, tuple):
body = proto2_to_proto1(body, headers or {})

if not issubclass(task_class.__class__, UserTaskMixin):
return

arguments_dict = task_class.arguments_as_dict(*body['args'], **body['kwargs'])
user_id = _get_user_id(arguments_dict)
task_id = body['id']
if body.get('callbacks', []):
_create_chain_entry(user_id, task_id, task_class, body['args'], body['kwargs'], body['callbacks'])
return
if body.get('chord', None):
_create_chord_entry(task_id, task_class, body, user_id)
return
parent = _get_or_create_group_parent(body, user_id)
name = task_class.generate_name(arguments_dict)
total_steps = task_class.calculate_total_steps(arguments_dict)
UserTaskStatus.objects.get_or_create(
task_id=task_id, defaults={'user_id': user_id, 'parent': parent, 'name': name, 'task_class': sender,
'total_steps': total_steps})
if parent:
parent.increment_total_steps(total_steps)


def _create_chain_entry(user_id, task_id, task_class, args, kwargs, callbacks, parent=None):
Expand Down
51 changes: 51 additions & 0 deletions user_tasks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Utility functions for handling Celery task protocol compatibility.
"""

from celery import chain


def proto2_to_proto1(body, headers):
"""
Convert a protocol v2 task body and headers to protocol v1 format.
"""
args, kwargs, embed = body
embedded = extract_proto2_embed(**embed)
chained = embedded.pop("chain", None)
new_body = dict(
extract_proto2_headers(**headers),
args=args,
kwargs=kwargs,
**embedded,
)
if chained:
new_body["callbacks"].append(chain(chained))
return new_body


def extract_proto2_headers(task_id, retries, eta, expires, group, timelimit, task, **_):
"""
Extract relevant headers from protocol v2 format.
"""
return {
"id": task_id,
"task": task,
"retries": retries,
"eta": eta,
"expires": expires,
"utc": True,
"taskset": group,
"timelimit": timelimit,
}


def extract_proto2_embed(callbacks=None, errbacks=None, task_chain=None, chord=None, **_):
"""
Extract embedded task metadata.
"""
return {
"callbacks": callbacks or [],
"errbacks": errbacks or [],
"chain": task_chain,
"chord": chord,
}