Call the task_received signal. #85

Merged (2 commits, Jan 8, 2024). Diff shows changes from all commits.
CHANGELOG.rst (10 additions, 0 deletions)
@@ -3,6 +3,15 @@
 Changelog
 #########

+next
+====
+
+Improvements
+------------
+
+* Call the task received signal for ``Batches`` task. (`#85 <https://github.com/clokep/celery-batches/pull/85>`_)
+
+
 0.8.1 (2023-06-27)
 ==================

@@ -34,6 +43,7 @@ Maintenance
 * Support Python 3.11. (`#75 <https://github.com/clokep/celery-batches/pull/75>`_)
 * Drop support for Python 3.7. (`#77 <https://github.com/clokep/celery-batches/pull/77>`_)

+
 0.7 (2022-05-02)
 ================
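For readers unfamiliar with the signal, here is a minimal, hypothetical sketch of what this change enables on the application side: a receiver connected to Celery's ``task_received`` signal now also fires once per request that is buffered for a ``Batches`` task. The handler name and printed fields are illustrative only, not part of this PR.

```python
from celery import signals


@signals.task_received.connect
def log_received(sender=None, request=None, **kwargs):
    # With this change, requests destined for a Batches task also arrive here,
    # one call per incoming request (i.e. before the batch itself executes).
    print(f"received {request.name}[{request.id}]")
```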
celery_batches/__init__.py (5 additions, 2 deletions)
@@ -17,6 +17,7 @@
 from celery_batches.trace import apply_batches_task

 from celery import VERSION as CELERY_VERSION
+from celery import signals
 from celery.app import Celery
 from celery.app.task import Task
 from celery.concurrency.base import BasePool
@@ -237,7 +238,7 @@ def task_message_handler(
             else:
                 body, headers, decoded, utc = proto1_to_proto2(message, body)

-            request = Req(
+            req = Req(
                 message,
                 on_ack=ack,
                 on_reject=reject,
@@ -251,7 +252,9 @@
                 utc=utc,
                 connection_errors=connection_errors,
             )
-            put_buffer(request)
+            put_buffer(req)
+
+            signals.task_received.send(sender=consumer, request=req)

             if self._tref is None:  # first request starts flush timer.
                 self._tref = timer.call_repeatedly(self.flush_interval, flush_buffer)
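One detail worth noting from the dispatch above: the signal's ``sender`` is the worker's ``Consumer``, matching how stock Celery sends ``task_received``, so receivers that care about a particular ``Batches`` task should inspect the ``request`` argument rather than filter by sender. A hedged sketch, with a hypothetical task name:

```python
from celery import signals


def on_batch_request(sender=None, request=None, **kwargs):
    # `sender` is the Consumer instance, not the task, so filter on the
    # request's task name instead of passing `sender=` to connect().
    if request.name == "myapp.tasks.count_clicks":  # hypothetical Batches task
        print(f"buffered for batching: {request.id}")


signals.task_received.connect(on_batch_request)
```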
t/integration/test_batches.py (25 additions, 12 deletions)
@@ -9,6 +9,8 @@
 from celery.contrib.testing.tasks import ping
 from celery.contrib.testing.worker import TestWorkController
 from celery.result import allow_join_result
+from celery.utils.dispatch import Signal
+from celery.worker.consumer.consumer import Consumer
 from celery.worker.request import Request

 import pytest
@@ -18,20 +20,28 @@

 class SignalCounter:
     def __init__(
-        self, expected_calls: int, callback: Optional[Callable[..., None]] = None
+        self,
+        signal: Signal,
+        expected_calls: int,
+        callback: Optional[Callable[..., None]] = None,
     ):
+        self.signal = signal
+        signal.connect(self)
         self.calls = 0
         self.expected_calls = expected_calls
         self.callback = callback

-    def __call__(self, sender: Union[Task, str], **kwargs: Any) -> None:
+    def __call__(self, sender: Union[Task, str, Consumer], **kwargs: Any) -> None:
         if isinstance(sender, Task):
-            sender_name = sender.name
+            task_name = sender.name
+        elif isinstance(sender, Consumer):
+            assert self.signal == signals.task_received
+            task_name = kwargs["request"].name
         else:
-            sender_name = sender
+            task_name = sender

         # Ignore pings, those are used to ensure the worker processes tasks.
-        if sender_name == "celery.ping":
+        if task_name == "celery.ping":
             return

         self.calls += 1
@@ -41,7 +51,9 @@ def __call__(self, sender: Union[Task, str], **kwargs: Any) -> None:
             self.callback(sender, **kwargs)

     def assert_calls(self) -> None:
-        assert self.calls == self.expected_calls
+        assert (
+            self.calls == self.expected_calls
+        ), f"Signal {self.signal.name} called incorrect number of times."


 def _wait_for_ping(ping_task_timeout: float = 10.0) -> None:
@@ -144,21 +156,23 @@ def test_signals(celery_app: Celery, celery_worker: TestWorkController) -> None:
         # Each task request gets published separately.
         (signals.before_task_publish, 2),
         (signals.after_task_publish, 2),
-        # The task only runs a single time.
+        (signals.task_sent, 2),
+        (signals.task_received, 2),
+        # The Batch task only runs a single time.
         (signals.task_prerun, 1),
         (signals.task_postrun, 1),
+        (signals.task_success, 1),
         # Other task signals are not implemented.
         (signals.task_retry, 0),
-        (signals.task_success, 1),
         (signals.task_failure, 0),
         (signals.task_revoked, 0),
+        (signals.task_internal_error, 0),
         (signals.task_unknown, 0),
         (signals.task_rejected, 0),
     )
     signal_counters = []
     for sig, expected_count in checks:
-        counter = SignalCounter(expected_count)
-        sig.connect(counter)
+        counter = SignalCounter(sig, expected_count)
         signal_counters.append(counter)

     # The batch runs after 2 task calls.
@@ -182,8 +196,7 @@ def test_current_task(celery_app: Celery, celery_worker: TestWorkController) -> None:
     def signal(sender: Union[Task, str], **kwargs: Any) -> None:
         assert celery_app.current_task.name == "t.integration.tasks.add"

-    counter = SignalCounter(1, signal)
-    signals.task_prerun.connect(counter)
+    counter = SignalCounter(signals.task_prerun, 1, signal)

     # The batch runs after 2 task calls.
     result_1 = add.delay(1)
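For completeness, a small usage sketch of the updated helper, assuming it is imported from the test module: because the constructor now calls ``signal.connect(self)``, callers no longer need the separate ``sig.connect(counter)`` step.

```python
from celery import signals

from t.integration.test_batches import SignalCounter  # the helper defined above

# Subscribes itself on construction; expects the signal to fire twice
# (once per task request sent to the Batches task).
received_counter = SignalCounter(signals.task_received, 2)

# ... publish two task requests and wait for the worker to pick them up ...

received_counter.assert_calls()
```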