diff --git a/tests/test_signals.py b/tests/test_signals.py index da8ebf76..9b0e5a0e 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -18,8 +18,9 @@ from user_tasks import user_task_stopped from user_tasks.models import UserTaskStatus -from user_tasks.signals import start_user_task +from user_tasks.signals import celery_app, create_user_task, start_user_task from user_tasks.tasks import UserTask +from user_tasks.utils import extract_proto2_embed, extract_proto2_headers, proto2_to_proto1 User = auth.get_user_model() @@ -189,6 +190,31 @@ def test_non_user_task_success(self): statuses = UserTaskStatus.objects.all() assert not statuses + def test_create_user_task_protocol_v2(self): + """The create_user_task signal handler should work with Celery protocol version 2.""" + + original_protocol = getattr(celery_app.conf, 'task_protocol', 1) + celery_app.conf.task_protocol = 2 + try: + body = ( + [self.user.id, 'Argument'], + {}, + {'callbacks': [], 'errbacks': [], 'task_chain': None, 'chord': None} + ) + headers = { + 'task_id': 'tid', 'retries': 0, 'eta': None, 'expires': None, + 'group': None, 'timelimit': [None, None], 'task': 'test_signals.sample_task' + } + create_user_task(sender='test_signals.sample_task', body=body, headers=headers) + statuses = UserTaskStatus.objects.all() + assert len(statuses) == 1 + status = statuses[0] + assert status.task_class == 'test_signals.sample_task' + assert status.user_id == self.user.id + assert status.name == 'SampleTask: Argument' + finally: + celery_app.conf.task_protocol = original_protocol + def _create_user_task(self, eager): """Create a task based on UserTaskMixin and verify some assertions about its corresponding status.""" result = sample_task.delay(self.user.id, 'Argument') @@ -530,3 +556,68 @@ def test_connections_not_closed_when_we_cant_get_a_connection(self, mock_close_o with mock.patch('user_tasks.signals.transaction.get_connection', side_effect=Exception): start_user_task(sender=SampleTask) assert mock_close_old_connections.called is False + + +class TestUtils: + """ + Unit tests for utility functions in user_tasks/utils.py. + """ + + def test_extract_proto2_headers(self): + headers = extract_proto2_headers( + task_id='abc123', retries=2, eta='2025-05-30T12:00:00', + expires=None, group='group1', timelimit=[10, 20], + task='my_task', extra='ignored') + assert headers == { + 'id': 'abc123', + 'task': 'my_task', + 'retries': 2, + 'eta': '2025-05-30T12:00:00', + 'expires': None, + 'utc': True, + 'taskset': 'group1', + 'timelimit': [10, 20], + } + + def test_extract_proto2_embed(self): + embed = extract_proto2_embed( + callbacks=['cb'], errbacks=['eb'], task_chain=['a', 'b'], + chord='chord1', extra='ignored') + assert embed == { + 'callbacks': ['cb'], + 'errbacks': ['eb'], + 'chain': ['a', 'b'], + 'chord': 'chord1', + } + embed = extract_proto2_embed() + assert embed == { + 'callbacks': [], + 'errbacks': [], + 'chain': None, + 'chord': None + } + + def test_proto2_to_proto1(self, monkeypatch): + monkeypatch.setattr( + 'user_tasks.utils.chain', + lambda x: f'chain({x})' + ) + body = ( + [1, 2], + {'foo': 'bar'}, + {'callbacks': ['cb'], 'errbacks': ['eb'], + 'task_chain': ['a'], 'chord': 'ch'} + ) + headers = { + 'task_id': 'tid', 'retries': 1, 'eta': 'eta', 'expires': 'exp', + 'group': 'grp', 'timelimit': [1, 2], 'task': 't', + 'extra': 'ignored' + } + result = proto2_to_proto1(body, headers) + assert result['id'] == 'tid' + assert result['args'] == [1, 2] + assert result['kwargs'] == {'foo': 'bar'} + assert result['callbacks'] == ['cb', "chain(['a'])"] + assert result['errbacks'] == ['eb'] + assert 'chain' not in result + assert result['chord'] == 'ch' diff --git a/user_tasks/signals.py b/user_tasks/signals.py index 1ef3f8d9..233453f9 100644 --- a/user_tasks/signals.py +++ b/user_tasks/signals.py @@ -5,6 +5,7 @@ import logging from uuid import uuid4 +from celery import current_app as celery_app from celery.signals import before_task_publish, task_failure, task_prerun, task_retry, task_success from django.contrib.auth import get_user_model @@ -16,40 +17,49 @@ from .exceptions import TaskCanceledException from .models import UserTaskStatus from .tasks import UserTaskMixin +from .utils import proto2_to_proto1 LOGGER = logging.getLogger(__name__) @before_task_publish.connect -def create_user_task(sender=None, body=None, **kwargs): +def create_user_task(sender=None, body=None, headers=None, **kwargs): """ Create a :py:class:`UserTaskStatus` record for each :py:class:`UserTaskMixin`. Also creates a :py:class:`UserTaskStatus` for each chain, chord, or group containing the new :py:class:`UserTaskMixin`. + + Supports Celery protocol v1 and v2. """ try: task_class = import_string(sender) except ImportError: return - if issubclass(task_class.__class__, UserTaskMixin): - arguments_dict = task_class.arguments_as_dict(*body['args'], **body['kwargs']) - user_id = _get_user_id(arguments_dict) - task_id = body['id'] - if body.get('callbacks', []): - _create_chain_entry(user_id, task_id, task_class, body['args'], body['kwargs'], body['callbacks']) - return - if body.get('chord', None): - _create_chord_entry(task_id, task_class, body, user_id) - return - parent = _get_or_create_group_parent(body, user_id) - name = task_class.generate_name(arguments_dict) - total_steps = task_class.calculate_total_steps(arguments_dict) - UserTaskStatus.objects.get_or_create( - task_id=task_id, defaults={'user_id': user_id, 'parent': parent, 'name': name, 'task_class': sender, - 'total_steps': total_steps}) - if parent: - parent.increment_total_steps(total_steps) + + if celery_app.conf.task_protocol == 2 and isinstance(body, tuple): + body = proto2_to_proto1(body, headers or {}) + + if not issubclass(task_class.__class__, UserTaskMixin): + return + + arguments_dict = task_class.arguments_as_dict(*body['args'], **body['kwargs']) + user_id = _get_user_id(arguments_dict) + task_id = body['id'] + if body.get('callbacks', []): + _create_chain_entry(user_id, task_id, task_class, body['args'], body['kwargs'], body['callbacks']) + return + if body.get('chord', None): + _create_chord_entry(task_id, task_class, body, user_id) + return + parent = _get_or_create_group_parent(body, user_id) + name = task_class.generate_name(arguments_dict) + total_steps = task_class.calculate_total_steps(arguments_dict) + UserTaskStatus.objects.get_or_create( + task_id=task_id, defaults={'user_id': user_id, 'parent': parent, 'name': name, 'task_class': sender, + 'total_steps': total_steps}) + if parent: + parent.increment_total_steps(total_steps) def _create_chain_entry(user_id, task_id, task_class, args, kwargs, callbacks, parent=None): diff --git a/user_tasks/utils.py b/user_tasks/utils.py new file mode 100644 index 00000000..4b553491 --- /dev/null +++ b/user_tasks/utils.py @@ -0,0 +1,51 @@ +""" +Utility functions for handling Celery task protocol compatibility. +""" + +from celery import chain + + +def proto2_to_proto1(body, headers): + """ + Convert a protocol v2 task body and headers to protocol v1 format. + """ + args, kwargs, embed = body + embedded = extract_proto2_embed(**embed) + chained = embedded.pop("chain", None) + new_body = dict( + extract_proto2_headers(**headers), + args=args, + kwargs=kwargs, + **embedded, + ) + if chained: + new_body["callbacks"].append(chain(chained)) + return new_body + + +def extract_proto2_headers(task_id, retries, eta, expires, group, timelimit, task, **_): + """ + Extract relevant headers from protocol v2 format. + """ + return { + "id": task_id, + "task": task, + "retries": retries, + "eta": eta, + "expires": expires, + "utc": True, + "taskset": group, + "timelimit": timelimit, + } + + +def extract_proto2_embed(callbacks=None, errbacks=None, task_chain=None, chord=None, **_): + """ + Extract embedded task metadata. + """ + return { + "callbacks": callbacks or [], + "errbacks": errbacks or [], + "chain": task_chain, + "chord": chord, + }