From bf91e1376ff273d2fa4ffff137337f940ddefe16 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 8 Jan 2024 13:01:02 -0500 Subject: [PATCH 1/2] Refactor asserting signal calls. --- t/integration/test_batches.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/t/integration/test_batches.py b/t/integration/test_batches.py index 8169179..61bf331 100644 --- a/t/integration/test_batches.py +++ b/t/integration/test_batches.py @@ -9,6 +9,8 @@ from celery.contrib.testing.tasks import ping from celery.contrib.testing.worker import TestWorkController from celery.result import allow_join_result +from celery.utils.dispatch import Signal +from celery.worker.consumer.consumer import Consumer from celery.worker.request import Request import pytest @@ -18,20 +20,28 @@ class SignalCounter: def __init__( - self, expected_calls: int, callback: Optional[Callable[..., None]] = None + self, + signal: Signal, + expected_calls: int, + callback: Optional[Callable[..., None]] = None, ): + self.signal = signal + signal.connect(self) self.calls = 0 self.expected_calls = expected_calls self.callback = callback - def __call__(self, sender: Union[Task, str], **kwargs: Any) -> None: + def __call__(self, sender: Union[Task, str, Consumer], **kwargs: Any) -> None: if isinstance(sender, Task): - sender_name = sender.name + task_name = sender.name + elif isinstance(sender, Consumer): + assert self.signal == signals.task_received + task_name = kwargs["request"].name else: - sender_name = sender + task_name = sender # Ignore pings, those are used to ensure the worker processes tasks. - if sender_name == "celery.ping": + if task_name == "celery.ping": return self.calls += 1 @@ -41,7 +51,9 @@ def __call__(self, sender: Union[Task, str], **kwargs: Any) -> None: self.callback(sender, **kwargs) def assert_calls(self) -> None: - assert self.calls == self.expected_calls + assert ( + self.calls == self.expected_calls + ), f"Signal {self.signal.name} called incorrect number of times." def _wait_for_ping(ping_task_timeout: float = 10.0) -> None: @@ -144,21 +156,23 @@ def test_signals(celery_app: Celery, celery_worker: TestWorkController) -> None: # Each task request gets published separately. (signals.before_task_publish, 2), (signals.after_task_publish, 2), + (signals.task_sent, 2), # The task only runs a single time. (signals.task_prerun, 1), (signals.task_postrun, 1), + (signals.task_received, 0), # Other task signals are not implemented. (signals.task_retry, 0), (signals.task_success, 1), (signals.task_failure, 0), (signals.task_revoked, 0), + (signals.task_internal_error, 0), (signals.task_unknown, 0), (signals.task_rejected, 0), ) signal_counters = [] for sig, expected_count in checks: - counter = SignalCounter(expected_count) - sig.connect(counter) + counter = SignalCounter(sig, expected_count) signal_counters.append(counter) # The batch runs after 2 task calls. @@ -182,8 +196,7 @@ def test_current_task(celery_app: Celery, celery_worker: TestWorkController) -> def signal(sender: Union[Task, str], **kwargs: Any) -> None: assert celery_app.current_task.name == "t.integration.tasks.add" - counter = SignalCounter(1, signal) - signals.task_prerun.connect(counter) + counter = SignalCounter(signals.task_prerun, 1, signal) # The batch runs after 2 task calls. result_1 = add.delay(1) From fbdac49306be018c3d89d98539f631cf89d5e42a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 8 Jan 2024 13:18:11 -0500 Subject: [PATCH 2/2] Call the task_received signal. --- CHANGELOG.rst | 10 ++++++++++ celery_batches/__init__.py | 7 +++++-- t/integration/test_batches.py | 6 +++--- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 67bcacb..97be071 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,15 @@ Changelog ######### +next +==== + +Improvements +------------ + +* Call the task received signal for ``Batches`` task. (`#85 `_) + + 0.8.1 (2023-06-27) ================== @@ -34,6 +43,7 @@ Maintenance * Support Python 3.11. (`#75 `_) * Drop support for Python 3.7. (`#77 `_) + 0.7 (2022-05-02) ================ diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index b8aa179..7ce1cbf 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -17,6 +17,7 @@ from celery_batches.trace import apply_batches_task from celery import VERSION as CELERY_VERSION +from celery import signals from celery.app import Celery from celery.app.task import Task from celery.concurrency.base import BasePool @@ -237,7 +238,7 @@ def task_message_handler( else: body, headers, decoded, utc = proto1_to_proto2(message, body) - request = Req( + req = Req( message, on_ack=ack, on_reject=reject, @@ -251,7 +252,9 @@ def task_message_handler( utc=utc, connection_errors=connection_errors, ) - put_buffer(request) + put_buffer(req) + + signals.task_received.send(sender=consumer, request=req) if self._tref is None: # first request starts flush timer. self._tref = timer.call_repeatedly(self.flush_interval, flush_buffer) diff --git a/t/integration/test_batches.py b/t/integration/test_batches.py index 61bf331..2283d07 100644 --- a/t/integration/test_batches.py +++ b/t/integration/test_batches.py @@ -157,13 +157,13 @@ def test_signals(celery_app: Celery, celery_worker: TestWorkController) -> None: (signals.before_task_publish, 2), (signals.after_task_publish, 2), (signals.task_sent, 2), - # The task only runs a single time. + (signals.task_received, 2), + # The Batch task only runs a single time. (signals.task_prerun, 1), (signals.task_postrun, 1), - (signals.task_received, 0), + (signals.task_success, 1), # Other task signals are not implemented. (signals.task_retry, 0), - (signals.task_success, 1), (signals.task_failure, 0), (signals.task_revoked, 0), (signals.task_internal_error, 0),