diff --git a/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py b/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py
index 5b5386570..6b9497609 100644
--- a/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py
+++ b/src/confluent_kafka/aio/producer/_buffer_timeout_manager.py
@@ -127,7 +127,4 @@ async def _flush_buffer_due_to_timeout(self):
         2. Execute batches from the batch processor
         """
-        # Create batches from current buffer
-        batches = self._batch_processor.create_batches()
-
-        # Execute batches with cleanup using the common function
-        await self._batch_processor._execute_batches(batches)
+        # Delegate to the batch processor, which owns batching, cleanup and failure recovery
+        await self._batch_processor.flush_buffer()
diff --git a/src/confluent_kafka/aio/producer/_producer_batch_processor.py b/src/confluent_kafka/aio/producer/_producer_batch_processor.py
index d0fab08b8..7ad59edb8 100644
--- a/src/confluent_kafka/aio/producer/_producer_batch_processor.py
+++ b/src/confluent_kafka/aio/producer/_producer_batch_processor.py
@@ -138,8 +138,25 @@ async def flush_buffer(self, target_topic=None):
         # Create batches for processing
         batches = self.create_batches(target_topic)
 
-        # Execute batches with cleanup
-        await self._execute_batches(batches, target_topic)
+        # Clear the buffer immediately to prevent race conditions
+        if target_topic is None:
+            # Clear entire buffer since we're processing all messages
+            self.clear_buffer()
+        else:
+            # Clear only messages for the target topic that we're processing
+            self._clear_topic_from_buffer(target_topic)
+
+        try:
+            # Execute batches with cleanup
+            await self._execute_batches(batches, target_topic)
+        except Exception:
+            # Add batches back to buffer on failure
+            try:
+                self._add_batches_back_to_buffer(batches)
+            except Exception:
+                logger.error(f"Error adding batches back to buffer on failure. Messages might be lost: {batches}")
+                raise
+            raise
 
     async def _execute_batches(self, batches, target_topic=None):
         """Execute batches and handle cleanup after successful execution
@@ -168,13 +185,33 @@ async def _execute_batches(self, batches, target_topic=None):
             # Re-raise the exception so caller knows the batch operation failed
             raise
 
-        # Clear successfully processed messages from buffer
-        if target_topic is None:
-            # Clear entire buffer since all messages were processed
-            self.clear_buffer()
-        else:
-            # Clear only messages for the target topic that were successfully processed
-            self._clear_topic_from_buffer(target_topic)
+    def _add_batches_back_to_buffer(self, batches):
+        """Add batches back to the buffer when execution fails
+
+        Args:
+            batches: List of MessageBatch objects to add back to buffer
+        """
+        for batch in batches:
+            # Add each message and its future back to the buffer
+            for i, message in enumerate(batch.messages):
+                # Reconstruct the original message data from the batch
+                msg_data = {
+                    'topic': batch.topic,
+                    'value': message.get('value'),
+                    'key': message.get('key'),
+                }
+
+                # Add optional fields if present
+                if 'partition' in message:
+                    msg_data['partition'] = message['partition']
+                if 'timestamp' in message:
+                    msg_data['timestamp'] = message['timestamp']
+                if 'headers' in message:
+                    msg_data['headers'] = message['headers']
+
+                # Add the message and its future back to the buffer
+                self._message_buffer.append(msg_data)
+                self._buffer_futures.append(batch.futures[i])
 
     def _group_messages_by_topic_and_partition(self):
         """Group buffered messages by topic and partition for optimal batch processing
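The clear-then-restore flow above is the heart of the change, so here is a distilled sketch of the pattern outside the diff. `TinyBuffer`, `send`, and `demo` are illustrative stand-ins, not the library's API; the real code clears via `clear_buffer()` / `_clear_topic_from_buffer()` and restores via `_add_batches_back_to_buffer()`.

```python
import asyncio

class TinyBuffer:
    """Illustrative stand-in for the batch processor's message buffer."""

    def __init__(self):
        self.messages = []

    async def flush(self, send):
        # Snapshot and clear BEFORE the first await: once this coroutine
        # suspends, a concurrent flush (e.g. the buffer timeout manager)
        # could otherwise pick up messages that are already in flight.
        batch, self.messages = self.messages, []
        try:
            await send(batch)
        except Exception:
            # Restore unsent messages so a later flush can retry them.
            # (The patch additionally guards this restore step and logs
            # if it fails, since that would mean the messages are lost.)
            self.messages.extend(batch)
            raise

async def demo():
    buf = TinyBuffer()
    buf.messages.append({'topic': 't', 'value': 'v'})

    async def failing_send(batch):
        raise RuntimeError('broker unavailable')

    try:
        await buf.flush(failing_send)
    except RuntimeError:
        pass
    assert len(buf.messages) == 1  # the message survived the failed flush

asyncio.run(demo())
```

One consequence worth noting: restored messages are appended to the back of the buffer, so ordering relative to messages enqueued during the failed flush is not preserved.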
diff --git a/tests/test_producer_batch_processor.py b/tests/test_producer_batch_processor.py
index b6f091b19..36daa718e 100644
--- a/tests/test_producer_batch_processor.py
+++ b/tests/test_producer_batch_processor.py
@@ -400,6 +400,70 @@ def test_future_based_error_handling(self):
         with self.assertRaises(RuntimeError):
             future.result()
 
+    def test_add_batches_back_to_buffer_basic(self):
+        """Test adding batches back to buffer with basic message data"""
+        from confluent_kafka.aio.producer._message_batch import create_message_batch
+
+        # Create test futures
+        future1 = asyncio.Future()
+        future2 = asyncio.Future()
+
+        # Create test batch with basic message data
+        batch = create_message_batch(
+            topic='test-topic',
+            messages=[
+                {'value': 'test1', 'key': 'key1'},
+                {'value': 'test2', 'key': 'key2'}
+            ],
+            futures=[future1, future2],
+            partition=0
+        )
+
+        # Ensure buffer is initially empty
+        self.assertTrue(self.batch_processor.is_buffer_empty())
+
+        # Add batch back to buffer
+        self.batch_processor._add_batches_back_to_buffer([batch])
+
+        # Verify buffer state
+        self.assertEqual(self.batch_processor.get_buffer_size(), 2)
+        self.assertFalse(self.batch_processor.is_buffer_empty())
+
+        # Verify message data was reconstructed correctly
+        self.assertEqual(self.batch_processor._message_buffer[0]['topic'], 'test-topic')
+        self.assertEqual(self.batch_processor._message_buffer[0]['value'], 'test1')
+        self.assertEqual(self.batch_processor._message_buffer[0]['key'], 'key1')
+        self.assertEqual(self.batch_processor._message_buffer[0]['partition'], 0)
+
+        self.assertEqual(self.batch_processor._message_buffer[1]['topic'], 'test-topic')
+        self.assertEqual(self.batch_processor._message_buffer[1]['value'], 'test2')
+        self.assertEqual(self.batch_processor._message_buffer[1]['key'], 'key2')
+        self.assertEqual(self.batch_processor._message_buffer[1]['partition'], 0)
+
+        # Verify futures are preserved
+        self.assertEqual(self.batch_processor._buffer_futures[0], future1)
+        self.assertEqual(self.batch_processor._buffer_futures[1], future2)
+
+    def test_add_batches_back_to_buffer_empty_batch(self):
+        """Test adding empty batch back to buffer"""
+        from confluent_kafka.aio.producer._message_batch import create_message_batch
+
+        # Create empty batch
+        batch = create_message_batch(
+            topic='test-topic',
+            messages=[],
+            futures=[],
+            partition=0
+        )
+
+        initial_size = self.batch_processor.get_buffer_size()
+
+        # Add empty batch back
+        self.batch_processor._add_batches_back_to_buffer([batch])
+
+        # Buffer size should remain unchanged
+        self.assertEqual(self.batch_processor.get_buffer_size(), initial_size)
+
 
 if __name__ == '__main__':
     # Run all tests
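The new tests cover `_add_batches_back_to_buffer` in isolation, but not the restore branch of `flush_buffer()` itself. A possible follow-up test, sketched under assumptions: `make_processor()` is a hypothetical stand-in for however the existing test class builds `self.batch_processor`, and the sketch assumes `flush_buffer()` reaches `create_batches()` even when nothing was enqueued beforehand (if it early-returns on an empty buffer, enqueue a message first).

```python
import asyncio
import unittest
from unittest.mock import patch

from confluent_kafka.aio.producer._message_batch import create_message_batch


class TestFlushBufferRestoresOnFailure(unittest.IsolatedAsyncioTestCase):

    async def test_failed_execute_restores_buffer(self):
        processor = make_processor()  # hypothetical fixture helper

        future = asyncio.Future()
        batch = create_message_batch(
            topic='test-topic',
            messages=[{'value': 'v1', 'key': 'k1'}],
            futures=[future],
            partition=0
        )

        # Force the failure path: create_batches returns our batch and
        # _execute_batches raises, so flush_buffer must restore the batch.
        with patch.object(processor, 'create_batches', return_value=[batch]), \
             patch.object(processor, '_execute_batches',
                          side_effect=RuntimeError('send failed')):
            with self.assertRaises(RuntimeError):
                await processor.flush_buffer()

        # The failed batch is back in the buffer and its future is intact.
        self.assertEqual(processor.get_buffer_size(), 1)
        self.assertIs(processor._buffer_futures[0], future)
```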