Skip to content

Commit

Permalink
Adding threaded producer (#68)
Browse files Browse the repository at this point in the history
* Add backpressure to slow stream processing to avoid filling up the buffer

* Add backpressure to slow stream processing to avoid filling up the buffer

* fix for crash on _commit due to rebalance

* Adding changelog producer thread

* Adding changelog producer thread

* Adding changelog producer thread

* fix for crash on _commit due to rebalance

* Custom faust build. Has producer instrumentation and agent coro name added for logging

* Adding producer in new thread

* updating version to 28

* updating version to 29

* adding producer send metrics

* gracefully shutting down producer thread

* publish on send

* producer threaded flag

* minor fixes

* minor fixes

* minor fixes

Co-authored-by: Aditya Vaderiyattil <aditya.vad@gmail.com>
  • Loading branch information
patkivikram and appu1232 committed Jan 21, 2021
1 parent 63853b1 commit fda7e52
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 17 deletions.
15 changes: 14 additions & 1 deletion faust/agents/agent.py
@@ -1,5 +1,6 @@
"""Agent implementation."""
import asyncio
import sys
import typing
from contextlib import suppress
from contextvars import ContextVar
Expand Down Expand Up @@ -658,7 +659,19 @@ async def _prepare_actor(self, aref: ActorRefT, beacon: NodeT) -> ActorRefT:
else:
# agent yields and is an AsyncIterator so we have to consume it.
coro = self._slurp(aref, aiter(aref))
task = asyncio.Task(self._execute_actor(coro, aref), loop=self.loop)
req_version = (3, 8)
cur_version = sys.version_info
if cur_version >= req_version:
task = asyncio.Task(
self._execute_actor(coro, aref),
loop=self.loop,
name=f"{str(aref)}-{self.channel.get_topic_name()}",
)
else:
task = asyncio.Task(
self._execute_actor(coro, aref),
loop=self.loop,
)
task._beacon = beacon # type: ignore
aref.actor_task = task
self._actors.add(aref)
Expand Down
7 changes: 5 additions & 2 deletions faust/app/base.py
Expand Up @@ -310,9 +310,12 @@ def sensors(self) -> Iterable[ServiceT]:

def kafka_producer(self) -> Iterable[ServiceT]:
"""Return list of services required to start Kafka producer."""
producers = []
if self._should_enable_kafka_producer():
return [self.app.producer]
return []
producers.append(self.app.producer)
if self.app.conf.producer_threaded:
producers.append(self.app.producer.threaded_producer)
return producers

def _should_enable_kafka_producer(self) -> bool:
if self.enable_kafka_producer is None:
Expand Down
2 changes: 1 addition & 1 deletion faust/stores/rocksdb.py
Expand Up @@ -367,7 +367,7 @@ async def assign_partitions(self, table: CollectionT, tps: Set[TP]) -> None:
await asyncio.sleep(0)

async def _try_open_db_for_partition(
self, partition: int, max_retries: int = 5, retry_delay: float = 1.0
self, partition: int, max_retries: int = 60, retry_delay: float = 1.0
) -> DB:
for i in range(max_retries):
try:
Expand Down
7 changes: 4 additions & 3 deletions faust/transport/consumer.py
Expand Up @@ -830,12 +830,13 @@ async def _commit_livelock_detector(self) -> None: # pragma: no cover
await self.sleep(interval)
async for sleep_time in self.itertimer(interval, name="livelock"):
if not self.app.rebalancing:
await self.verify_all_partitions_active()
await self.app.loop.run_in_executor(
None, self.verify_all_partitions_active
)

async def verify_all_partitions_active(self) -> None:
def verify_all_partitions_active(self) -> None:
now = monotonic()
for tp in self.assignment():
await self.sleep(0)
if not self.should_stop:
self.verify_event_path(now, tp)

Expand Down
151 changes: 149 additions & 2 deletions faust/transport/drivers/aiokafka.py
@@ -1,8 +1,9 @@
"""Message transport using :pypi:`aiokafka`."""
import asyncio
import typing
from asyncio import Lock
from asyncio import Lock, QueueEmpty
from collections import deque
from functools import partial
from time import monotonic
from typing import (
Any,
Expand Down Expand Up @@ -44,6 +45,7 @@
from kafka.partitioner.default import DefaultPartitioner
from kafka.protocol.metadata import MetadataRequest_v1
from mode import Service, get_logger
from mode.threads import ServiceThread, WorkerThread
from mode.utils import text
from mode.utils.futures import StampedeWrapper
from mode.utils.objects import cached_property
Expand All @@ -66,7 +68,14 @@
ThreadDelegateConsumer,
ensure_TPset,
)
from faust.types import TP, ConsumerMessage, HeadersArg, RecordMetadata
from faust.types import (
TP,
ConsumerMessage,
FutureMessage,
HeadersArg,
PendingMessage,
RecordMetadata,
)
from faust.types.auth import CredentialsT
from faust.types.transports import ConsumerT, PartitionerT, ProducerT
from faust.utils.kafka.protocol.admin import CreateTopicsRequest
Expand Down Expand Up @@ -250,6 +259,141 @@ async def on_stop(self) -> None:
transport._topic_waiters.clear()


class ThreadedProducer(ServiceThread):
    """Producer service thread that publishes messages taken from a queue.

    Messages are handed over through :attr:`event_queue` and sent by a
    background task (:meth:`push_events`) using a dedicated
    :class:`aiokafka.AIOKafkaProducer`, configured from the settings
    helpers of the default producer passed to the constructor.
    """

    # aiokafka producer used for sending; created in ``on_start``.
    _producer: Optional[aiokafka.AIOKafkaProducer] = None
    # Queue of FutureMessage objects waiting to be published;
    # created in ``on_start``.
    event_queue: Optional[asyncio.Queue] = None
    # Object providing ``_settings_default()``, ``_settings_auth()`` and
    # ``_settings_extra()`` (see ``_new_producer``).
    _default_producer: Any = None
    # Strong reference to the ``push_events`` task: the event loop only
    # keeps weak references to tasks, so an unreferenced task may be
    # garbage-collected while still running.
    _push_events_task: Optional[asyncio.Task] = None
    # Application object; used here for ``app.sensors``.
    app: Any

    def __init__(
        self,
        default_producer,
        app,
        *,
        executor: Any = None,
        loop: asyncio.AbstractEventLoop = None,
        thread_loop: asyncio.AbstractEventLoop = None,
        Worker: Type[WorkerThread] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the service thread.

        Arguments:
            default_producer: Producer whose ``_settings_*`` helpers are
                reused to configure the producer created by this service.
            app: Application instance (provides ``sensors``).
            executor, loop, thread_loop, Worker, **kwargs: Forwarded
                unchanged to :class:`ServiceThread`.
        """
        super().__init__(
            executor=executor,
            loop=loop,
            thread_loop=thread_loop,
            Worker=Worker,
            **kwargs,
        )
        self._default_producer = default_producer
        self.app = app

    async def flush(self) -> None:
        """Wait for producer to finish transmitting all buffered messages."""
        queue = self.event_queue
        # The queue only exists once ``on_start`` has run; without it there
        # is nothing buffered on our side.
        if queue is not None:
            while True:
                try:
                    msg = queue.get_nowait()
                except QueueEmpty:
                    break
                else:
                    await self.publish_message(msg)
        if self._producer is not None:
            await self._producer.flush()

    def _new_producer(self, transactional_id: str = None) -> aiokafka.AIOKafkaProducer:
        """Create an aiokafka producer bound to ``self.thread_loop``.

        The keyword arguments are the merged default, auth and extra
        settings of the default producer (later dicts win on key clashes).
        """
        return aiokafka.AIOKafkaProducer(
            loop=self.thread_loop,
            **{
                **self._default_producer._settings_default(),
                **self._default_producer._settings_auth(),
                **self._default_producer._settings_extra(),
            },
            transactional_id=transactional_id,
        )

    async def on_start(self) -> None:
        """Create the queue, start the producer and the publishing task."""
        self.event_queue = asyncio.Queue()
        producer = self._producer = self._new_producer()
        await producer.start()
        # Keep the task referenced so it cannot be garbage-collected and
        # so it can be cancelled when the thread stops.
        self._push_events_task = asyncio.create_task(self.push_events())

    async def on_thread_stop(self) -> None:
        """Call when producer thread is stopping."""
        logger.info("Stopping producer thread")
        await super().on_thread_stop()
        # when method queue is stopped, we can stop the producer
        if self._producer is not None:
            await self.flush()
            await self._producer.stop()
        # Cancel the publishing loop only after the producer has been
        # flushed and stopped, so cancellation cannot interrupt a send.
        task, self._push_events_task = self._push_events_task, None
        if task is not None and not task.done():
            task.cancel()

    async def push_events(self) -> None:
        """Publish queued messages one at a time until cancelled."""
        while True:
            event = await self.event_queue.get()
            await self.publish_message(event)

    async def publish_message(
        self, fut_other: FutureMessage, wait: bool = False
    ) -> Awaitable[RecordMetadata]:
        """Fulfill promise to publish message to topic.

        Arguments:
            fut_other: Future message taken from the queue. Only its
                ``message`` is used; a new :class:`FutureMessage` is
                created around it for this call.
            wait: If true, use ``send_and_wait`` and finalize the message
                before returning; otherwise send, register the channel's
                ``_on_published`` callback and wait for delivery.

        Returns:
            The finalized message when ``wait`` is true, otherwise the
            completed future returned by the producer's ``send``.
        """
        # presumably a new future is needed because ``fut_other`` was
        # created for a different event loop -- TODO confirm.
        fut = FutureMessage(fut_other.message)
        message: PendingMessage = fut.message
        topic = message.channel.get_topic_name()
        key: bytes = cast(bytes, message.key)
        value: bytes = cast(bytes, message.value)
        partition: Optional[int] = message.partition
        timestamp: float = cast(float, message.timestamp)
        headers: Optional[HeadersArg] = message.headers
        logger.debug(
            "send: topic=%r k=%r v=%r timestamp=%r partition=%r",
            topic,
            key,
            value,
            timestamp,
            partition,
        )
        producer = self._producer
        state = self.app.sensors.on_send_initiated(
            producer,
            topic,
            message=message,
            keysize=len(key) if key else 0,
            valsize=len(value) if value else 0,
        )
        # Seconds -> milliseconds; a falsy timestamp is passed through as-is.
        timestamp_ms = int(timestamp * 1000.0) if timestamp else timestamp
        # Header mappings are converted to a list of (key, value) pairs.
        if isinstance(headers, Mapping):
            headers = list(headers.items())
        if wait:
            ret: RecordMetadata = await producer.send_and_wait(
                topic=topic,
                key=key,
                value=value,
                partition=partition,
                timestamp_ms=timestamp_ms,
                headers=headers,
            )
            # NOTE(review): ``_finalize_message`` is not defined in this
            # class; confirm a base class provides it, otherwise this
            # branch fails with AttributeError.
            return await self._finalize_message(fut, ret)
        else:
            fut2 = cast(
                asyncio.Future,
                await producer.send(
                    topic=topic,
                    key=key,
                    value=value,
                    partition=partition,
                    timestamp_ms=timestamp_ms,
                    headers=headers,
                ),
            )
            callback = partial(
                fut.message.channel._on_published,
                message=fut,
                state=state,
                producer=producer,
            )
            fut2.add_done_callback(cast(Callable, callback))
            await fut2
            return fut2


class AIOKafkaConsumerThread(ConsumerThread):
_consumer: Optional[aiokafka.AIOKafkaConsumer] = None
_pending_rebalancing_spans: Deque[opentracing.Span]
Expand Down Expand Up @@ -884,6 +1028,9 @@ class Producer(base.Producer):
_transaction_producers: typing.Dict[str, aiokafka.AIOKafkaProducer] = {}
_trn_locks: typing.Dict[str, Lock] = {}

def create_threaded_producer(self):
    """Build a ThreadedProducer tied to this producer and its app."""
    threaded = ThreadedProducer(app=self.app, default_producer=self)
    return threaded

def __post_init__(self) -> None:
self._send_on_produce_message = self.app.on_produce_message.send
if self.partitioner is None:
Expand Down
35 changes: 27 additions & 8 deletions faust/transport/producer.py
Expand Up @@ -7,21 +7,25 @@
- Sending messages.
"""
import asyncio
import time
from asyncio import QueueEmpty
from typing import Any, Awaitable, Mapping, Optional, cast

from mode import Seconds, Service
from mode import Seconds, Service, get_logger
from mode.threads import ServiceThread

from faust.types import AppT, HeadersArg
from faust.types.transports import ProducerBufferT, ProducerT, TransportT
from faust.types.tuples import TP, FutureMessage, RecordMetadata

__all__ = ["Producer"]
logger = get_logger(__name__)


class ProducerBuffer(Service, ProducerBufferT):

app: AppT = None
max_messages = 100
queue: Optional[asyncio.Queue] = None

def __post_init__(self) -> None:
self.pending = asyncio.Queue()
Expand All @@ -32,7 +36,14 @@ def put(self, fut: FutureMessage) -> None:
The message will be eventually produced, you can await
the future to wait for that to happen.
"""
self.pending.put_nowait(fut)
if self.app.conf.producer_threaded:
if not self.queue:
self.queue = self.threaded_producer.event_queue
asyncio.run_coroutine_threadsafe(
self.queue.put(fut), self.threaded_producer.thread_loop
)
else:
self.pending.put_nowait(fut)

async def on_stop(self) -> None:
await self.flush()
Expand Down Expand Up @@ -85,7 +96,11 @@ async def wait_until_ebb(self) -> None:
is of an acceptable size before resuming stream processing flow.
"""
if self.size > self.max_messages:
logger.warning(f"producer buffer full size {self.size}")
start_time = time.time()
await self.flush_atmost(self.max_messages)
end_time = time.time()
logger.info(f"producer flush took {end_time-start_time}")

@Service.task
async def _handle_pending(self) -> None:
Expand All @@ -111,12 +126,13 @@ class Producer(Service, ProducerT):
app: AppT

_api_version: str
threaded_producer: Optional[ServiceThread] = None

def __init__(
self,
transport: TransportT,
loop: asyncio.AbstractEventLoop = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self.transport = transport
self.app = self.transport.app
Expand All @@ -134,8 +150,11 @@ def __init__(
api_version = self._api_version = conf.producer_api_version
assert api_version is not None
super().__init__(loop=loop or self.transport.loop, **kwargs)

self.buffer = ProducerBuffer(loop=self.loop, beacon=self.beacon)
if conf.producer_threaded:
self.threaded_producer = self.create_threaded_producer()
self.buffer.threaded_producer = self.threaded_producer
self.buffer.app = self.app

async def on_start(self) -> None:
await self.add_runtime_dependency(self.buffer)
Expand All @@ -149,7 +168,7 @@ async def send(
timestamp: Optional[float],
headers: Optional[HeadersArg],
*,
transactional_id: str = None
transactional_id: str = None,
) -> Awaitable[RecordMetadata]:
"""Schedule message to be sent by producer."""
raise NotImplementedError()
Expand All @@ -166,7 +185,7 @@ async def send_and_wait(
timestamp: Optional[float],
headers: Optional[HeadersArg],
*,
transactional_id: str = None
transactional_id: str = None,
) -> RecordMetadata:
"""Send message and wait for it to be transmitted."""
raise NotImplementedError()
Expand All @@ -187,7 +206,7 @@ async def create_topic(
retention: Seconds = None,
compacting: bool = None,
deleting: bool = None,
ensure_created: bool = False
ensure_created: bool = False,
) -> None:
"""Create/declare topic on server."""
raise NotImplementedError()
Expand Down
15 changes: 15 additions & 0 deletions faust/types/settings/settings.py
Expand Up @@ -127,6 +127,7 @@ def __init__(
producer_max_request_size: int = None,
producer_partitioner: SymbolArg[PartitionerT] = None,
producer_request_timeout: Seconds = None,
producer_threaded: bool = False,
# RPC settings:
reply_create_topic: bool = None,
reply_expires: Seconds = None,
Expand Down Expand Up @@ -1326,6 +1327,20 @@ def producer_request_timeout(self) -> float:
producer batches expire and will no longer be retried.
"""

@sections.Producer.setting(
    params.Bool,
    version_introduced="0.5.0",
    env_name="PRODUCER_THREADED",
    default=False,
)
def producer_threaded(self) -> bool:
    """Thread separate producer for send_soon.

    If True, spin up a different producer in a different thread
    to be used for messages buffered up for producing via
    send_soon function. Disabled by default; can also be set
    through the ``PRODUCER_THREADED`` environment variable.
    """

@sections.RPC.setting(
params.Bool,
env_name="APP_REPLY_CREATE_TOPIC",
Expand Down

0 comments on commit fda7e52

Please sign in to comment.