Skip to content

Commit

Permalink
Fix aiohttp ThreadedProducer driver python36, add unit tests (#277)
Browse files Browse the repository at this point in the history
* fix: rewrite non enabled unit tests

* chore: add tests for aiokafka threadedproducer

Add tests for basic setup of the threadedproducer, including
startup, teardown and publishing messages tests.

* fix: waited message publishing and proper teardown

Fix message publishing with wait=True for the ThreadedProducer
of aiokafka.

Also add proper teardown logic for the task creation to allow
proper unit testing.

* fix: python3.6 support for ThreadedProducer

replace asyncio.create_task to the create_task call in the
thread_loop as asyncio.create_task is unsupported in
python 3.6

* chore: wait till worker task is completed on teardown
  • Loading branch information
cbrand committed Feb 28, 2022
1 parent 89c1614 commit 784bee7
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 96 deletions.
23 changes: 19 additions & 4 deletions faust/transport/drivers/aiokafka.py
Expand Up @@ -284,7 +284,9 @@ class ThreadedProducer(ServiceThread):
_producer: Optional[aiokafka.AIOKafkaProducer] = None
event_queue: Optional[asyncio.Queue] = None
_default_producer: Optional[aiokafka.AIOKafkaProducer] = None
_push_events_task: Optional[asyncio.Task] = None
app: None
stopped: bool

def __init__(
self,
Expand Down Expand Up @@ -336,20 +338,29 @@ async def on_start(self) -> None:
self.event_queue = asyncio.Queue()
producer = self._producer = self._new_producer()
await producer.start()
asyncio.create_task(self.push_events())
self.stopped = False
self._push_events_task = self.thread_loop.create_task(self.push_events())

async def on_thread_stop(self) -> None:
"""Call when producer thread is stopping."""
logger.info("Stopping producer thread")
await super().on_thread_stop()
self.stopped = True
# when method queue is stopped, we can stop the consumer
if self._producer is not None:
await self.flush()
await self._producer.stop()
if self._push_events_task is not None:
while not self._push_events_task.done():
await asyncio.sleep(0.1)

async def push_events(self):
while True:
event = await self.event_queue.get()
while not self.stopped:
try:
event = await asyncio.wait_for(self.event_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
continue

self.app.sensors.on_threaded_producer_buffer_processed(
app=self.app, size=self.event_queue.qsize()
)
Expand Down Expand Up @@ -396,7 +407,11 @@ async def publish_message(
timestamp_ms=timestamp_ms,
headers=headers,
)
return await self._finalize_message(fut, ret)
fut.message.channel._on_published(
message=fut, state=state, producer=producer
)
fut.set_result(ret)
return fut
else:
fut2 = cast(
asyncio.Future,
Expand Down
176 changes: 140 additions & 36 deletions tests/unit/transport/drivers/test_aiokafka.py
Expand Up @@ -31,11 +31,13 @@
ConsumerStoppedError,
Producer,
ProducerSendError,
ThreadedProducer,
Transport,
credentials_to_aiokafka_auth,
server_list,
)
from faust.types import TP
from faust.types.tuples import FutureMessage, PendingMessage

TP1 = TP("topic", 23)
TP2 = TP("topix", 23)
Expand Down Expand Up @@ -1287,7 +1289,7 @@ class MyPartitioner:
my_partitioner = MyPartitioner()


class TestProducer:
class ProducerBaseTest:
@pytest.fixture()
def producer(self, *, app, _producer):
producer = Producer(app.transport)
Expand Down Expand Up @@ -1325,6 +1327,43 @@ def inner():

return inner

def assert_new_producer(
self,
producer,
acks=-1,
api_version="auto",
bootstrap_servers=["localhost:9092"], # noqa,
client_id=f"faust-{faust.__version__}",
compression_type=None,
linger_ms=0,
max_batch_size=16384,
max_request_size=1000000,
request_timeout_ms=1200000,
security_protocol="PLAINTEXT",
**kwargs,
):
with patch("aiokafka.AIOKafkaProducer") as AIOKafkaProducer:
p = producer._new_producer()
assert p is AIOKafkaProducer.return_value
AIOKafkaProducer.assert_called_once_with(
acks=acks,
api_version=api_version,
bootstrap_servers=bootstrap_servers,
client_id=client_id,
compression_type=compression_type,
linger_ms=linger_ms,
max_batch_size=max_batch_size,
max_request_size=max_request_size,
request_timeout_ms=request_timeout_ms,
security_protocol=security_protocol,
loop=producer.loop,
partitioner=producer.partitioner,
transactional_id=None,
**kwargs,
)


class TestProducer(ProducerBaseTest):
@pytest.mark.conf(producer_partitioner=my_partitioner)
def test_producer__uses_custom_partitioner(self, *, producer):
assert producer.partitioner is my_partitioner
Expand Down Expand Up @@ -1446,41 +1485,6 @@ def test__new_producer__using_settings(self, expected_args, *, app):
producer = Producer(app.transport)
self.assert_new_producer(producer, **expected_args)

def assert_new_producer(
self,
producer,
acks=-1,
api_version="auto",
bootstrap_servers=["localhost:9092"], # noqa,
client_id=f"faust-{faust.__version__}",
compression_type=None,
linger_ms=0,
max_batch_size=16384,
max_request_size=1000000,
request_timeout_ms=1200000,
security_protocol="PLAINTEXT",
**kwargs,
):
with patch("aiokafka.AIOKafkaProducer") as AIOKafkaProducer:
p = producer._new_producer()
assert p is AIOKafkaProducer.return_value
AIOKafkaProducer.assert_called_once_with(
acks=acks,
api_version=api_version,
bootstrap_servers=bootstrap_servers,
client_id=client_id,
compression_type=compression_type,
linger_ms=linger_ms,
max_batch_size=max_batch_size,
max_request_size=max_request_size,
request_timeout_ms=request_timeout_ms,
security_protocol=security_protocol,
loop=producer.loop,
partitioner=producer.partitioner,
transactional_id=None,
**kwargs,
)

@pytest.mark.asyncio
async def test__new_producer__default(self, *, app):
producer = Producer(app.transport)
Expand Down Expand Up @@ -1748,6 +1752,106 @@ def test_supports_headers(self, *, producer):
assert producer.supports_headers()


class TestThreadedProducer(ProducerBaseTest):
@pytest.fixture()
def threaded_producer(self, *, producer: Producer):
return producer.create_threaded_producer()

@pytest.fixture()
def new_producer_mock(self, *, threaded_producer: ThreadedProducer):
mock = threaded_producer._new_producer = Mock(
name="_new_producer",
return_value=Mock(
start=AsyncMock(),
stop=AsyncMock(),
flush=AsyncMock(),
send_and_wait=AsyncMock(),
send=AsyncMock(),
),
)
return mock

@pytest.fixture()
def mocked_producer(self, *, new_producer_mock: Mock):
return new_producer_mock.return_value

@pytest.mark.asyncio
async def test_on_start(
self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop
):
await threaded_producer.on_start()
try:
assert threaded_producer._producer is mocked_producer
threaded_producer._new_producer.assert_called_once_with()
mocked_producer.start.coro.assert_called_once_with()
finally:
await threaded_producer.start()
await threaded_producer.stop()

@pytest.mark.asyncio
async def test_on_thread_stop(
self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop
):
await threaded_producer.start()
await threaded_producer.on_thread_stop()
try:
mocked_producer.flush.coro.assert_called_once_with()
mocked_producer.stop.coro.assert_called_once_with()
finally:
await threaded_producer.stop()

@pytest.mark.asyncio
async def test_publish_message(
self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop
):
await threaded_producer.start()
try:
await threaded_producer.publish_message(
fut_other=FutureMessage(
PendingMessage(
channel=Mock(),
key="Test",
value="Test",
partition=None,
timestamp=None,
headers=None,
key_serializer=None,
value_serializer=None,
callback=None,
)
)
)
mocked_producer.send.coro.assert_called_once()
finally:
await threaded_producer.stop()

@pytest.mark.asyncio
async def test_publish_message_with_wait(
self, *, threaded_producer: ThreadedProducer, mocked_producer: Mock, loop
):
await threaded_producer.start()
try:
await threaded_producer.publish_message(
wait=True,
fut_other=FutureMessage(
PendingMessage(
channel=Mock(),
key="Test",
value="Test",
partition=None,
timestamp=None,
headers=None,
key_serializer=None,
value_serializer=None,
callback=None,
)
),
)
mocked_producer.send_and_wait.coro.assert_called_once()
finally:
await threaded_producer.stop()


class TestTransport:
@pytest.fixture()
def transport(self, *, app):
Expand Down

0 comments on commit 784bee7

Please sign in to comment.