diff --git a/arroyo/dlq.py b/arroyo/dlq.py index 7519934d..d8dcd4ec 100644 --- a/arroyo/dlq.py +++ b/arroyo/dlq.py @@ -65,6 +65,9 @@ def __eq__(self, other: Any) -> bool: and self.needs_commit == other.needs_commit ) + def __reduce__(self) -> Tuple[Any, Tuple[Any, ...]]: + return self.__class__, (self.partition, self.offset, self.needs_commit) + @dataclass(frozen=True) class DlqLimit: diff --git a/tests/processing/strategies/test_run_task_with_multiprocessing.py b/tests/processing/strategies/test_run_task_with_multiprocessing.py index 63922a04..82b5fbaa 100644 --- a/tests/processing/strategies/test_run_task_with_multiprocessing.py +++ b/tests/processing/strategies/test_run_task_with_multiprocessing.py @@ -7,6 +7,7 @@ import pytest from arroyo.backends.kafka import KafkaPayload +from arroyo.dlq import InvalidMessage from arroyo.processing.strategies import MessageRejected from arroyo.processing.strategies.run_task_with_multiprocessing import ( MessageBatch, @@ -631,3 +632,26 @@ def test_output_block_resizing_without_limits() -> None: ) in TestingMetricsBackend.calls ) + + +def message_processor_raising_invalid_message(x: Message[KafkaPayload]) -> KafkaPayload: + raise InvalidMessage(Partition(topic=Topic("test_topic"), index=0), offset=1000) + + +def test_multiprocessing_with_invalid_message() -> None: + next_step = Mock() + + strategy = RunTaskWithMultiprocessing( + message_processor_raising_invalid_message, + next_step, + num_processes=2, + max_batch_size=1, + max_batch_time=60, + ) + + strategy.submit(Message(Value(KafkaPayload(None, b"x" * 10, []), {}))) + + strategy.poll() + strategy.close() + with pytest.raises(InvalidMessage): + strategy.join(timeout=3) diff --git a/tests/test_dlq.py b/tests/test_dlq.py index 6d6f933b..1b024521 100644 --- a/tests/test_dlq.py +++ b/tests/test_dlq.py @@ -1,3 +1,4 @@ +import pickle from datetime import datetime from typing import Generator from unittest.mock import ANY @@ -109,3 +110,10 @@ def test_dlq_policy_wrapper() -> None: ) wrapper.produce(message) wrapper.flush({partition: 11}) + + +def test_invalid_message_pickleable() -> None: + exc = InvalidMessage(Partition(Topic("test_topic"), 0), 2) + pickled_exc = pickle.dumps(exc) + unpickled_exc = pickle.loads(pickled_exc) + assert exc == unpickled_exc