From 54f4369322a5d586468ffaab61a321aca7136bd8 Mon Sep 17 00:00:00 2001 From: Lyn Nagara Date: Wed, 8 Nov 2023 13:59:38 -0800 Subject: [PATCH] fix(dlq): RunTaskWithMultiprocessing supports forwarding downstream invalid message (#301) Invalid messages raised from strategies downstream of RunTaskWithMultiprocessing are not correctly handled currently. Currently messages from that batch will be re-submitted to the next step multiple times. This change ensures they are correctly re-raised and avoids the duplicate messages downstream. --- .../run_task_with_multiprocessing.py | 27 ++++++++++++++++--- .../test_run_task_with_multiprocessing.py | 24 +++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/arroyo/processing/strategies/run_task_with_multiprocessing.py b/arroyo/processing/strategies/run_task_with_multiprocessing.py index 107b6d83..5d637936 100644 --- a/arroyo/processing/strategies/run_task_with_multiprocessing.py +++ b/arroyo/processing/strategies/run_task_with_multiprocessing.py @@ -223,7 +223,6 @@ def parallel_run_task_worker_apply( output_block: SharedMemory, start_index: int = 0, ) -> ParallelRunTaskResult[TResult]: - valid_messages_transformed: MessageBatch[ Union[InvalidMessage, Message[Union[FilteredPayload, TResult]]] ] = MessageBatch(output_block) @@ -614,6 +613,13 @@ def __check_for_results_impl(self, timeout: Optional[float] = None) -> None: try: self.__next_step.poll() + except InvalidMessage as e: + # For the next invocation of __check_for_results, start at this message + result.valid_messages_transformed.reset_iterator(idx) + self.__invalid_messages.append(e) + raise e + + try: self.__next_step.submit(message) except MessageRejected: @@ -624,6 +630,12 @@ def __check_for_results_impl(self, timeout: Optional[float] = None) -> None: "arroyo.strategies.run_task_with_multiprocessing.batch.backpressure" ) raise NextStepTimeoutError() + except InvalidMessage as e: + # For the next invocation of __check_for_results, skip over this message + # since we do not want to re-submit it. + result.valid_messages_transformed.reset_iterator(idx + 1) + self.__invalid_messages.append(e) + raise e if result.next_index_to_process != len(input_batch): self.__metrics.increment( @@ -770,14 +782,21 @@ def terminate(self) -> None: self.__next_step.terminate() def join(self, timeout: Optional[float] = None) -> None: + start_join = time.time() deadline = time.time() + timeout if timeout is not None else None self.__forward_invalid_offsets() logger.debug("Waiting for %s batches...", len(self.__processes)) - self.__check_for_results( - timeout=timeout, - ) + while True: + elapsed = time.time() - start_join + try: + self.__check_for_results( + timeout=timeout - elapsed if timeout is not None else None, + ) + break + except InvalidMessage: + raise logger.debug("Waiting for %s...", self.__pool) self.__pool.terminate() diff --git a/tests/processing/strategies/test_run_task_with_multiprocessing.py b/tests/processing/strategies/test_run_task_with_multiprocessing.py index 82b5fbaa..d2c4e259 100644 --- a/tests/processing/strategies/test_run_task_with_multiprocessing.py +++ b/tests/processing/strategies/test_run_task_with_multiprocessing.py @@ -655,3 +655,27 @@ def test_multiprocessing_with_invalid_message() -> None: strategy.close() with pytest.raises(InvalidMessage): strategy.join(timeout=3) + + +def test_reraise_invalid_message() -> None: + next_step = Mock() + partition = Partition(Topic("test"), 0) + offset = 5 + next_step.poll.side_effect = InvalidMessage(partition, offset) + + strategy = RunTaskWithMultiprocessing( + run_multiply_times_two, + next_step, + num_processes=2, + max_batch_size=1, + max_batch_time=60, + ) + + strategy.submit(Message(Value(KafkaPayload(None, b"x" * 10, []), {}))) + + with pytest.raises(InvalidMessage): + strategy.poll() + + next_step.poll.reset_mock(side_effect=True) + strategy.close() + strategy.join()