From 784bee7f843fd830c3724c2ed55a6eda26fef785 Mon Sep 17 00:00:00 2001
From: Christoph Brand
Date: Mon, 28 Feb 2022 19:15:24 +0100
Subject: [PATCH] Fix aiokafka ThreadedProducer driver python36, add unit tests (#277)

* fix: rewrite non-enabled unit tests

* chore: add tests for aiokafka threadedproducer

Add tests for the basic setup of the ThreadedProducer, covering startup,
teardown, and message publishing.

* fix: waited message publishing and proper teardown

Fix message publishing with wait=True for the aiokafka ThreadedProducer.
Also add proper teardown logic for the task creation so it can be unit
tested.

* fix: python3.6 support for ThreadedProducer

Replace asyncio.create_task with a create_task call on the thread_loop,
as asyncio.create_task is not available in Python 3.6.

* chore: wait until the worker task is completed on teardown
---
 faust/transport/drivers/aiokafka.py           |  23 ++-
 tests/unit/transport/drivers/test_aiokafka.py | 176 ++++++++++++++----
 tests/unit/transport/test_producer.py         | 122 ++++++------
 3 files changed, 225 insertions(+), 96 deletions(-)

diff --git a/faust/transport/drivers/aiokafka.py b/faust/transport/drivers/aiokafka.py
index 3ac986ea0..50fc24ff2 100644
--- a/faust/transport/drivers/aiokafka.py
+++ b/faust/transport/drivers/aiokafka.py
@@ -284,7 +284,9 @@ class ThreadedProducer(ServiceThread):
     _producer: Optional[aiokafka.AIOKafkaProducer] = None
     event_queue: Optional[asyncio.Queue] = None
     _default_producer: Optional[aiokafka.AIOKafkaProducer] = None
+    _push_events_task: Optional[asyncio.Task] = None
     app: None
+    stopped: bool

     def __init__(
         self,
@@ -336,20 +338,29 @@ async def on_start(self) -> None:
         self.event_queue = asyncio.Queue()
         producer = self._producer = self._new_producer()
         await producer.start()
-        asyncio.create_task(self.push_events())
+        self.stopped = False
+        self._push_events_task = self.thread_loop.create_task(self.push_events())

     async def on_thread_stop(self) -> None:
         """Call when producer thread is stopping."""
         logger.info("Stopping producer thread")
         await super().on_thread_stop()
+        self.stopped = True
         # when method queue is stopped, we can stop the consumer
         if self._producer is not None:
             await self.flush()
             await self._producer.stop()
+        if self._push_events_task is not None:
+            while not self._push_events_task.done():
+                await asyncio.sleep(0.1)

     async def push_events(self):
-        while True:
-            event = await self.event_queue.get()
+        while not self.stopped:
+            try:
+                event = await asyncio.wait_for(self.event_queue.get(), timeout=0.1)
+            except asyncio.TimeoutError:
+                continue
+
             self.app.sensors.on_threaded_producer_buffer_processed(
                 app=self.app, size=self.event_queue.qsize()
             )
@@ -396,7 +407,11 @@ async def publish_message(
                 timestamp_ms=timestamp_ms,
                 headers=headers,
             )
-            return await self._finalize_message(fut, ret)
+            fut.message.channel._on_published(
+                message=fut, state=state, producer=producer
+            )
+            fut.set_result(ret)
+            return fut
         else:
             fut2 = cast(
                 asyncio.Future,
diff --git a/tests/unit/transport/drivers/test_aiokafka.py b/tests/unit/transport/drivers/test_aiokafka.py
index d34753854..6876f7c9a 100644
--- a/tests/unit/transport/drivers/test_aiokafka.py
+++ b/tests/unit/transport/drivers/test_aiokafka.py
@@ -31,11 +31,13 @@
     ConsumerStoppedError,
     Producer,
     ProducerSendError,
+    ThreadedProducer,
     Transport,
     credentials_to_aiokafka_auth,
     server_list,
 )
 from faust.types import TP
+from faust.types.tuples import FutureMessage, PendingMessage

 TP1 = TP("topic", 23)
 TP2 = TP("topix", 23)
@@ -1287,7 +1289,7 @@ class MyPartitioner:
 my_partitioner = 
MyPartitioner() -class TestProducer: +class ProducerBaseTest: @pytest.fixture() def producer(self, *, app, _producer): producer = Producer(app.transport) @@ -1325,6 +1327,43 @@ def inner(): return inner + def assert_new_producer( + self, + producer, + acks=-1, + api_version="auto", + bootstrap_servers=["localhost:9092"], # noqa, + client_id=f"faust-{faust.__version__}", + compression_type=None, + linger_ms=0, + max_batch_size=16384, + max_request_size=1000000, + request_timeout_ms=1200000, + security_protocol="PLAINTEXT", + **kwargs, + ): + with patch("aiokafka.AIOKafkaProducer") as AIOKafkaProducer: + p = producer._new_producer() + assert p is AIOKafkaProducer.return_value + AIOKafkaProducer.assert_called_once_with( + acks=acks, + api_version=api_version, + bootstrap_servers=bootstrap_servers, + client_id=client_id, + compression_type=compression_type, + linger_ms=linger_ms, + max_batch_size=max_batch_size, + max_request_size=max_request_size, + request_timeout_ms=request_timeout_ms, + security_protocol=security_protocol, + loop=producer.loop, + partitioner=producer.partitioner, + transactional_id=None, + **kwargs, + ) + + +class TestProducer(ProducerBaseTest): @pytest.mark.conf(producer_partitioner=my_partitioner) def test_producer__uses_custom_partitioner(self, *, producer): assert producer.partitioner is my_partitioner @@ -1446,41 +1485,6 @@ def test__new_producer__using_settings(self, expected_args, *, app): producer = Producer(app.transport) self.assert_new_producer(producer, **expected_args) - def assert_new_producer( - self, - producer, - acks=-1, - api_version="auto", - bootstrap_servers=["localhost:9092"], # noqa, - client_id=f"faust-{faust.__version__}", - compression_type=None, - linger_ms=0, - max_batch_size=16384, - max_request_size=1000000, - request_timeout_ms=1200000, - security_protocol="PLAINTEXT", - **kwargs, - ): - with patch("aiokafka.AIOKafkaProducer") as AIOKafkaProducer: - p = producer._new_producer() - assert p is AIOKafkaProducer.return_value - AIOKafkaProducer.assert_called_once_with( - acks=acks, - api_version=api_version, - bootstrap_servers=bootstrap_servers, - client_id=client_id, - compression_type=compression_type, - linger_ms=linger_ms, - max_batch_size=max_batch_size, - max_request_size=max_request_size, - request_timeout_ms=request_timeout_ms, - security_protocol=security_protocol, - loop=producer.loop, - partitioner=producer.partitioner, - transactional_id=None, - **kwargs, - ) - @pytest.mark.asyncio async def test__new_producer__default(self, *, app): producer = Producer(app.transport) @@ -1748,6 +1752,106 @@ def test_supports_headers(self, *, producer): assert producer.supports_headers() +class TestThreadedProducer(ProducerBaseTest): + @pytest.fixture() + def threaded_producer(self, *, producer: Producer): + return producer.create_threaded_producer() + + @pytest.fixture() + def new_producer_mock(self, *, threaded_producer: ThreadedProducer): + mock = threaded_producer._new_producer = Mock( + name="_new_producer", + return_value=Mock( + start=AsyncMock(), + stop=AsyncMock(), + flush=AsyncMock(), + send_and_wait=AsyncMock(), + send=AsyncMock(), + ), + ) + return mock + + @pytest.fixture() + def mocked_producer(self, *, new_producer_mock: Mock): + return new_producer_mock.return_value + + @pytest.mark.asyncio + async def test_on_start( + self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop + ): + await threaded_producer.on_start() + try: + assert threaded_producer._producer is mocked_producer + 
threaded_producer._new_producer.assert_called_once_with() + mocked_producer.start.coro.assert_called_once_with() + finally: + await threaded_producer.start() + await threaded_producer.stop() + + @pytest.mark.asyncio + async def test_on_thread_stop( + self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop + ): + await threaded_producer.start() + await threaded_producer.on_thread_stop() + try: + mocked_producer.flush.coro.assert_called_once_with() + mocked_producer.stop.coro.assert_called_once_with() + finally: + await threaded_producer.stop() + + @pytest.mark.asyncio + async def test_publish_message( + self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop + ): + await threaded_producer.start() + try: + await threaded_producer.publish_message( + fut_other=FutureMessage( + PendingMessage( + channel=Mock(), + key="Test", + value="Test", + partition=None, + timestamp=None, + headers=None, + key_serializer=None, + value_serializer=None, + callback=None, + ) + ) + ) + mocked_producer.send.coro.assert_called_once() + finally: + await threaded_producer.stop() + + @pytest.mark.asyncio + async def test_publish_message_with_wait( + self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop + ): + await threaded_producer.start() + try: + await threaded_producer.publish_message( + wait=True, + fut_other=FutureMessage( + PendingMessage( + channel=Mock(), + key="Test", + value="Test", + partition=None, + timestamp=None, + headers=None, + key_serializer=None, + value_serializer=None, + callback=None, + ) + ), + ) + mocked_producer.send_and_wait.coro.assert_called_once() + finally: + await threaded_producer.stop() + + class TestTransport: @pytest.fixture() def transport(self, *, app): diff --git a/tests/unit/transport/test_producer.py b/tests/unit/transport/test_producer.py index 3279a79aa..87bb8697a 100644 --- a/tests/unit/transport/test_producer.py +++ b/tests/unit/transport/test_producer.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any +from typing import Any, Optional from unittest.mock import PropertyMock import pytest @@ -62,58 +62,58 @@ async def on_send(fut): async def test_wait_until_ebb(self, *, buf): buf.max_messages = 10 - def create_send_pending_mock(max_messages): - sent_messages = 0 + flush_atmost_call_count = 0 - async def _inner(): - nonlocal sent_messages - if sent_messages < max_messages: - sent_messages += 1 - return - else: - await asyncio.Future() + async def flush_atmost(max_messages: Optional[int]) -> int: + assert ( + max_messages is None or max_messages == buf.max_messages + ), "Max messages set not to the max messages buffer" + nonlocal flush_atmost_call_count + flush_atmost_call_count += 1 - return create_send_pending_mock + await asyncio.sleep(0) + return 0 - buf._send_pending = create_send_pending_mock(10) + buf.flush_atmost = flush_atmost await buf.start() - self._put(buf, range(20)) - assert buf.size == 20 + original_size = buf.__class__.size + loop = asyncio.get_event_loop() + try: + buf.__class__.size = PropertyMock(return_value=20) - await buf.wait_until_ebb() - assert list(buf.pending._queue) == list(range(10, 20)) - assert buf.size == 10 + task = loop.create_task(buf.wait_until_ebb()) + await asyncio.sleep(0) + assert flush_atmost_call_count == 1 + assert not task.done(), ( + "The wait_until_ebb has been finished even " + "though flush atmost did not return" + ) + + buf.__class__.size = PropertyMock(return_value=10) + await asyncio.sleep(0) + assert task.done(), ( + "The wait_until_ebb did not complete even " 
+ "though the size is beneath the max size" + ) - await buf.wait_until_ebb() - assert list(buf.pending._queue) == list(range(10, 20)) - assert buf.size == 10 + assert ( + flush_atmost_call_count > 0 + ), "The wait_until_ebb did not call the flush_atmost function" + task = loop.create_task(buf.wait_until_ebb()) + await asyncio.sleep(0) + assert task.done(), ( + "The wait_until_ebb function did not finish even " + "though the buffer is small enough" + ) + finally: + buf.__class__.size = original_size + await buf.stop() @pytest.mark.asyncio async def test_flush(self, *, buf): - def create_send_pending_mock(max_messages): - sent_messages = 0 - - async def _inner(): - nonlocal sent_messages - if sent_messages < max_messages: - sent_messages += 1 - return - else: - await asyncio.Future() - - return create_send_pending_mock - - buf._send_pending = create_send_pending_mock(10) - await buf.start() - - assert not buf.size + buf.flush_atmost = AsyncMock(return_value=0) await buf.flush() - - self._put(buf, range(10)) - assert buf.size == 10 - - await buf.flush() - assert not buf.size + buf.flush_atmost.assert_called() def _put(self, buf, items): for item in items: @@ -121,10 +121,13 @@ def _put(self, buf, items): @pytest.mark.asyncio async def test_flush_atmost(self, *, buf): + + sent_messages = 0 + def create_send_pending_mock(max_messages): - sent_messages = 0 + nonlocal sent_messages - async def _inner(): + async def _inner(*args: Any): nonlocal sent_messages if sent_messages < max_messages: sent_messages += 1 @@ -132,21 +135,28 @@ async def _inner(): else: await asyncio.Future() - return create_send_pending_mock + return _inner - assert await buf.flush_atmost(10) == 0 + await buf.start() + buf._send_pending = create_send_pending_mock(13) + + try: + assert await buf.flush_atmost(10) == 0 - self._put(buf, range(3)) - assert buf.size == 3 - assert await buf.flush_atmost(10) == 3 + self._put(buf, range(3)) + assert buf.size == 3 + assert await buf.flush_atmost(10) > 0 + assert sent_messages == 3 - self._put(buf, range(10)) - assert buf.size == 10 - assert (await buf.flush_atmost(4)) == 4 - assert buf.size == 6 + self._put(buf, range(10)) + assert buf.size == 10 + await buf.flush_atmost(4) - assert (await buf.flush_atmost(6)) == 6 - assert not buf.size + await buf.flush_atmost(6) + assert not buf.size + assert sent_messages == 13 + finally: + await buf.stop() @pytest.mark.asyncio async def test_flush_atmost_with_simulated_threaded_behavior(self, *, buf):