In [None]:
# | default_exp _components.producer_decorator

In [None]:
# | export

import asyncio
import functools
import logging
import random
import time
from asyncio import iscoroutinefunction  # do not use the version from inspect
from dataclasses import dataclass
from functools import partial
from inspect import Parameter
from typing import *

from aiokafka import AIOKafkaProducer
from aiokafka.errors import KafkaTimeoutError, RequestTimedOutError
from aiokafka.producer.message_accumulator import BatchBuilder
from pydantic import BaseModel

from fastkafka._components.logger import get_logger, cached_log
from fastkafka._components.meta import export
from fastkafka._components.helpers import remove_suffix

In [None]:
import asyncio
import unittest
from contextlib import asynccontextmanager, contextmanager
from itertools import product
from unittest.mock import ANY, Mock, call
from _pytest import monkeypatch

from pydantic import Field

from fastkafka._components.logger import suppress_timestamps
from fastkafka.encoder import avro_encoder, json_encoder
from fastkafka._testing.in_memory_broker import InMemoryBroker, InMemoryProducer, InMemoryConsumer

In [None]:
# | export

# Module-level logger used by the exported helpers below (release_callback,
# produce_single and send_batch log through it).
logger = get_logger(__name__)

In [None]:
# Notebook-only: strip timestamps from log lines and rebuild the logger at INFO
# level so the captured cell outputs below are stable.
suppress_timestamps()
logger = get_logger(__name__, level=logging.INFO)
logger.info("ok")

[INFO] __main__: ok


In [None]:
# | export


BaseSubmodel = TypeVar("BaseSubmodel", bound=Union[List[BaseModel], BaseModel])


@dataclass
@export("fastkafka")
class KafkaEvent(Generic[BaseSubmodel]):
    """
    A generic container for a Kafka event: a message plus an optional key.

    The type parameter BaseSubmodel is bound to a pydantic.BaseModel or a list
    of pydantic.BaseModel instances. The bound is a static-typing hint only;
    the dataclass itself does not validate ``message`` at runtime.

    Attributes:
        message (BaseSubmodel): The message (or list of messages) contained in the Kafka event.
        key (bytes, optional): The optional key used to identify the Kafka event. Defaults to None.
    """

    message: BaseSubmodel
    key: Optional[bytes] = None

In [None]:
# Only the message given: the key falls back to None.
event = KafkaEvent("Some message")
assert (event.message, event.key) == ("Some message", None)

# Message and key given positionally: both are stored as passed.
event = KafkaEvent("Some message", b"123")
assert (event.message, event.key) == ("Some message", b"123")

In [None]:
# | export


def unwrap_from_kafka_event(var_type: Union[Type, Parameter]) -> Union[Type, Parameter]:
    """
    Strip a ``KafkaEvent[...]`` wrapper from a type annotation.

    Args:
        var_type: The type (or inspect Parameter) to inspect.

    Returns:
        The first type argument when ``var_type`` is a parametrised KafkaEvent,
        otherwise ``var_type`` unchanged.

    Example:
        - ``KafkaEvent[str]`` -> ``str``
        - ``int`` -> ``int``
    """
    # hasattr first: plain classes and Parameter objects have no __origin__,
    # so the KafkaEvent comparison is never reached for them.
    is_wrapped = hasattr(var_type, "__origin__") and var_type.__origin__ == KafkaEvent  # type: ignore
    return var_type.__args__[0] if is_wrapped else var_type  # type: ignore

In [None]:
# A parametrised KafkaEvent is unwrapped; any other type passes through.
assert unwrap_from_kafka_event(KafkaEvent[int]) is int
assert unwrap_from_kafka_event(int) is int

In [None]:
# | export

# Values a producer function may return: a single message or a list of
# messages, either bare or wrapped in a KafkaEvent (which can carry a key).
ProduceReturnTypes = Union[
    BaseModel, KafkaEvent[BaseModel], List[BaseModel], KafkaEvent[List[BaseModel]]
]

# A producer function, plain or async, returning one of ProduceReturnTypes.
# NOTE: producer_decorator in this module rejects the synchronous form.
ProduceCallable = Union[
    Callable[..., ProduceReturnTypes], Callable[..., Awaitable[ProduceReturnTypes]]
]

In [None]:
# # | export


# def _to_json_utf8(o: Any) -> bytes:
#     """Converts to JSON and then encodes with UTF-8"""
#     if hasattr(o, "json"):
#         return o.json().encode("utf-8")  # type: ignore
#     else:
#         return json.dumps(o).encode("utf-8")

In [None]:
# assert _to_json_utf8({"a": 1, "b": [2, 3]}) == b'{"a": 1, "b": [2, 3]}'


# Sample pydantic message used by the notebook tests in this module.
# (Documented with a comment rather than a class docstring so the model's
# generated schema is left exactly as before.)
class ExampleMsg(BaseModel):
    name: str = Field()
    age: int


# assert _to_json_utf8(ExampleMsg(name="Davor", age=12)) == b'{"name": "Davor", "age": 12}'

In [None]:
# | export


def _wrap_in_event(
    message: Union[BaseModel, List[BaseModel], KafkaEvent]
) -> KafkaEvent:
    """
    Normalise a producer function's return value into a KafkaEvent.

    Args:
        message: A message, a list of messages, or an already-built KafkaEvent.

    Returns:
        ``message`` itself when it is already a KafkaEvent (including
        subclasses), otherwise a new KafkaEvent wrapping it with no key.
    """
    # isinstance (rather than an exact type() comparison) so that KafkaEvent
    # subclasses are passed through instead of being wrapped a second time.
    if isinstance(message, KafkaEvent):
        return message
    return KafkaEvent(message)

In [None]:
message = ExampleMsg(name="Davor", age=12)
wrapped = _wrap_in_event(message)

# A bare message is wrapped in a fresh KafkaEvent that has no key.
assert type(wrapped) is KafkaEvent
assert (wrapped.message, wrapped.key) == (message, None)

In [None]:
message = KafkaEvent(ExampleMsg(name="Davor", age=12), b"123")
wrapped = _wrap_in_event(message)

# An existing KafkaEvent comes back with its message and key intact.
assert type(wrapped) is KafkaEvent
assert (wrapped.message, wrapped.key) == (message.message, b"123")

In [None]:
# | export


def release_callback(
    fut: asyncio.Future, topic: str, wrapped_val: KafkaEvent[BaseModel]
) -> None:
    """
    Done-callback for the future returned by ``producer.send``.

    Logs a warning through ``cached_log`` when the send failed. Successful and
    cancelled futures are ignored.

    Args:
        fut: The completed send future.
        topic: Topic the message was sent to (used only in the log message).
        wrapped_val: The event that was sent (used only in the log message).
    """
    # Future.exception() raises CancelledError on a cancelled future; inside a
    # done-callback that would surface as an "Exception in callback" loop error.
    if fut.cancelled():
        return
    exc = fut.exception()
    if exc is None:
        return
    # `fut.exception()={exc!r}` reproduces exactly what the f-string `=`
    # specifier produced before, without calling exception() a second time.
    # timeout/log_id are passed to cached_log; presumably they de-duplicate
    # repeated warnings (confirm against cached_log).
    cached_log(
        logger,
        f"release_callback(): Exception fut.exception()={exc!r}, raised when producing {wrapped_val.message=} to {topic=}",
        level=logging.WARNING,
        timeout=1,
        log_id="release_callback()",
    )

In [None]:
# | export


async def produce_single(  # type: ignore
    producer: AIOKafkaProducer,
    topic: str,
    encoder_fn: Callable[[BaseModel], bytes],
    wrapped_val: KafkaEvent[BaseModel],
) -> None:
    """
    Send one message through the Kafka producer, retrying on timeouts.

    The message is encoded with ``encoder_fn`` and sent with the event's key.
    A ``KafkaTimeoutError`` from ``producer.send`` is logged and the send is
    retried after a one second pause, with no retry limit. Once the message is
    enqueued, ``release_callback`` is attached to the returned future so that
    delivery failures get logged; this coroutine does not await delivery.

    Args:
        producer (AIOKafkaProducer): The Kafka producer object.
        topic (str): The topic to which the message will be sent.
        encoder_fn (Callable[[BaseModel], bytes]): The encoding function to encode the message.
        wrapped_val (KafkaEvent[BaseModel]): The wrapped Kafka event containing the message.
    """
    on_done = partial(release_callback, topic=topic, wrapped_val=wrapped_val)
    sent = False
    while not sent:
        try:
            delivery = await producer.send(
                topic, encoder_fn(wrapped_val.message), key=wrapped_val.key
            )
        except KafkaTimeoutError as e:
            logger.warning(f"produce_single(): Exception {e=} raised when producing {wrapped_val.message} to {topic=}, sleeping for 1 second and retrying..")
            await asyncio.sleep(1)
        else:
            delivery.add_done_callback(on_done)
            sent = True

In [None]:
class FakeProducer:
    """
    Minimal stand-in for AIOKafkaProducer used by the notebook tests.

    ``send`` raises ``KafkaTimeoutError`` on its first five calls (tracked in
    ``counter``) and afterwards returns ``return_future``. This exercises the
    retry loops of produce_single and send_batch. When no future is supplied,
    an already-resolved one with result "Some result" is used.
    """

    def __init__(self, return_future: asyncio.Future = None):
        self.counter = 0
        self.return_future = (
            return_future if return_future is not None else self._resolved_future()
        )

    @staticmethod
    def _resolved_future() -> asyncio.Future:
        # Created on the current event loop, matching a bare asyncio.Future().
        done = asyncio.Future()
        done.set_result("Some result")
        return done

    async def send(self, *args, **kwargs):
        if self.counter >= 5:
            return self.return_future
        self.counter += 1
        raise KafkaTimeoutError()

    async def send_batch(self, *args, **kwargs):
        # Shares send()'s timeout counter and return value.
        return await self.send()

    def create_batch(self):
        return unittest.mock.MagicMock()

    def add_done_callback(self, *args):
        return None

    async def partitions_for(self, *args):
        return ["partition_1", "partition_2"]

    def start(*_args, **_kwargs):
        return None

    def stop(*_args, **_kwargs):
        return None

In [None]:
await produce_single(
    FakeProducer(),
    topic="test_topic",
    encoder_fn=json_encoder,
    wrapped_val=KafkaEvent(message=ExampleMsg(name="Davor", age=12), key=b"test"),
)



In [None]:
timeout_future = asyncio.Future()
timeout_future.set_exception(RequestTimedOutError())

await produce_single(
    FakeProducer(return_future=timeout_future),
    topic="test_topic",
    encoder_fn=json_encoder,
    wrapped_val=KafkaEvent(message=ExampleMsg(name="Davor", age=12), key=b"test"),
)



In [None]:
with InMemoryBroker() as broker:
    ProducerClass = InMemoryProducer(broker)
    producer = ProducerClass()
    await producer.start()

    await produce_single(
        producer,
        topic="test_topic",
        encoder_fn=json_encoder,
        wrapped_val=KafkaEvent(message=ExampleMsg(name="Davor", age=12), key=b"test"),
    )

    await producer.stop()

[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping


In [None]:
# | export


async def send_batch(  # type: ignore
    producer: AIOKafkaProducer, topic: str, batch: BatchBuilder, key: Optional[bytes]
) -> None:
    """
    Sends a batch of messages to the Kafka producer.

    The target partition is chosen once per call. It is picked at random among
    the topic's partitions when ``key`` is None; otherwise it comes from the
    producer's private ``_partition`` helper applied to the key. A
    ``KafkaTimeoutError`` from ``producer.send_batch`` is logged and the send
    is retried after one second, with no retry limit.

    Args:
        producer (AIOKafkaProducer): The Kafka producer object.
        topic (str): The topic to which the messages will be sent.
        batch (BatchBuilder): The batch builder object containing the messages.
        key (Optional[bytes]): The optional key used to identify the batch of messages.

    Returns:
        None
    """
    # Always queried, even for keyed batches: presumably this also ensures the
    # topic's metadata is loaded before _partition() consults it (verify).
    partitions = await producer.partitions_for(topic)
    if key is None:
        partition = random.choice(tuple(partitions))  # nosec
    else:
        # Positional args: topic, partition, key, value, serialized_key, serialized_value.
        partition = producer._partition(topic, None, None, None, key, None)
    while True:
        try:
            await producer.send_batch(batch, topic, partition=partition)
            break
        except KafkaTimeoutError as e:
            logger.warning(f"send_batch(): Exception {e} raised when producing {batch} to {topic=}, sleeping for 1 second and retrying..")
            await asyncio.sleep(1)
    


async def produce_batch(  # type: ignore
    producer: AIOKafkaProducer,
    topic: str,
    encoder_fn: Callable[[BaseModel], bytes],
    wrapped_val: KafkaEvent[List[BaseModel]],
) -> None:
    """
    Sends a list of messages to the Kafka producer using record batches.

    Each message is encoded with ``encoder_fn`` and appended to a batch created
    by ``producer.create_batch()``. When ``append`` returns None (the batch is
    full), that batch is sent and the message is appended to a fresh batch.
    The last batch is always sent. Every record carries ``wrapped_val.key``.

    Args:
        producer (AIOKafkaProducer): The Kafka producer object.
        topic (str): The topic to which the messages will be sent.
        encoder_fn (Callable[[BaseModel], bytes]): The encoding function to encode the messages.
        wrapped_val (KafkaEvent[List[BaseModel]]): The wrapped Kafka event containing the list of messages.

    Returns:
        None
    """
    key = wrapped_val.key
    batch = producer.create_batch()

    for message in wrapped_val.message:
        # Encode once: the value may have to be appended to a second batch.
        value = encoder_fn(message)
        metadata = batch.append(
            key=key,
            value=value,
            timestamp=int(time.time() * 1000),
        )
        if metadata is None:
            # Current batch is full: send it, then start a new batch with
            # this message. The event's key is kept on the spilled record too.
            await send_batch(producer, topic, batch, key)
            batch = producer.create_batch()
            batch.append(key=key, value=value, timestamp=int(time.time() * 1000))

    await send_batch(producer, topic, batch, key)

In [None]:
msgs = [ExampleMsg(name="Davor", age=12) for _ in range(500)]

await produce_batch(
    FakeProducer(),
    topic="test_topic",
    encoder_fn=json_encoder,
    wrapped_val=KafkaEvent(message=msgs, key=None),
)



In [None]:
msgs = [ExampleMsg(name="Davor", age=12) for _ in range(500)]
    
with InMemoryBroker() as broker:
    ProducerClass = InMemoryProducer(broker)
    producer = ProducerClass()
    await producer.start()

    await produce_batch(
        producer,
        topic="test_topic",
        encoder_fn=json_encoder,
        wrapped_val=KafkaEvent(message=msgs, key=b"test"),
    )

    await producer.stop()

[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping


In [None]:
# | export


def producer_decorator(
    producer_store: Dict[str, Any],
    func: ProduceCallable,
    topic_key: str,
    encoder_fn: Callable[[BaseModel], bytes],
) -> ProduceCallable:
    """
    Decorator for Kafka producer functions.

    The returned coroutine function awaits ``func`` and then sends its result
    to Kafka. A list result is sent via produce_batch and anything else via
    produce_single, after wrapping in a KafkaEvent if needed. The original
    return value of ``func`` is returned to the caller unchanged.

    Args:
        producer_store (Dict[str, Any]): Dictionary to store the Kafka producer objects.
            It is read at call time (not decoration time); the producer is the second
            element of the 4-tuple stored under ``topic_key``.
        func (ProduceCallable): The function to be decorated.
        topic_key (str): The key used to identify the topic; the topic name is
            derived from it with ``remove_suffix``.
        encoder_fn (Callable[[BaseModel], bytes]): The encoding function to encode the messages.

    Returns:
        ProduceCallable: The decorated function.

    Raises:
        ValueError: If the decorated function is synchronous.
    """
    if not iscoroutinefunction(func):
        raise ValueError(
            "Synchronous functions are not supported for produce operation"
        )

    @functools.wraps(func)
    async def _produce_async(*args: Any, **kwargs: Any) -> ProduceReturnTypes:
        # The decorator arguments are captured by closure, not exposed as
        # keyword parameters. Callers' kwargs such as `f=` or `topic_key=`
        # therefore reach `func` instead of silently replacing the wrapper's
        # internals.
        return_val = await func(*args, **kwargs)  # type: ignore
        wrapped_val = _wrap_in_event(return_val)
        _, producer, _, _ = producer_store[topic_key]
        topic = remove_suffix(topic_key)

        if isinstance(wrapped_val.message, list):
            await produce_batch(producer, topic, encoder_fn, wrapped_val)
        else:
            await produce_single(producer, topic, encoder_fn, wrapped_val)
        return return_val

    return _produce_async

In [None]:
# Sample pydantic message for the producer_decorator tests in later cells.
# (Comment instead of a class docstring so the model's schema is unchanged.)
class MockMsg(BaseModel):
    name: str = "Micky Mouse"
    id: int = 123


# Instance built from the field defaults; the tests produce this message.
mock_msg = MockMsg()

# Key under which the tests register the producer in the producer store.
topic = "test_topic_1"

In [None]:
async def _f() -> None:
#     print("Mock called")
    loop = asyncio.get_running_loop()

    # Create a new Future object.
    return loop.create_future()


@contextmanager
def mock_InMemoryProducer_send() -> Generator[Mock, None, None]:
    """
    Patch **send** of **InMemoryProducer** for the duration of the block.

    Yields:
        The patch mock. Its return value is a task running ``_f``. Creating
        that task needs a running event loop, so enter this from async code.
    """
    patch_target = "fastkafka._testing.in_memory_broker.InMemoryProducer.send"
    with unittest.mock.patch(patch_target) as send_mock:
        send_mock.return_value = asyncio.create_task(_f())
        yield send_mock

In [None]:
@asynccontextmanager
async def mock_producer_send_env() -> AsyncGenerator[
    Tuple[Mock, AIOKafkaProducer], None
]:
    """
    Yield ``(send_mock, producer)`` for single-message producer tests.

    ``producer`` is a started InMemoryProducer on a fresh InMemoryBroker, and
    ``send_mock`` patches ``InMemoryProducer.send``. The producer is stopped
    while the broker is still active. Stopping happens only if ``start()``
    succeeded, so setup errors are not masked by an unbound ``producer``.
    """
    with mock_InMemoryProducer_send() as send_mock:
        with InMemoryBroker() as broker:
            ProducerClass = InMemoryProducer(broker)
            producer = ProducerClass()
            await producer.start()
            try:
                yield send_mock, producer
            finally:
                await producer.stop()

In [None]:
@asynccontextmanager
async def mock_producer_batch_env() -> AsyncGenerator[
    Tuple[Mock, Mock, AIOKafkaProducer], None
]:
    """
    Yield ``(batch_mock, send_batch_mock, producer)`` for batch producer tests.

    ``InMemoryProducer.create_batch`` is patched to always return
    ``batch_mock``, and ``InMemoryProducer.send_batch`` is patched with
    ``send_batch_mock``. ``producer`` is a started InMemoryProducer on a fresh
    InMemoryBroker. It is stopped while the broker is still active, and only
    if ``start()`` succeeded.
    """
    with unittest.mock.patch(
        "fastkafka._testing.in_memory_broker.InMemoryProducer.send_batch"
    ) as send_batch_mock, unittest.mock.patch(
        "fastkafka._testing.in_memory_broker.InMemoryProducer.create_batch"
    ) as create_batch_mock:
        batch_mock = Mock()
        create_batch_mock.return_value = batch_mock
        send_batch_mock.return_value = asyncio.create_task(_f())
        with InMemoryBroker() as broker:
            ProducerClass = InMemoryProducer(broker)
            producer = ProducerClass()
            await producer.start()
            try:
                yield batch_mock, send_batch_mock, producer
            finally:
                await producer.stop()

In [None]:
async def func_async(mock_msg: MockMsg) -> MockMsg:
    return mock_msg


def func_sync(mock_msg: MockMsg) -> MockMsg:
    return mock_msg


is_sync = False
for encoder_fn in [json_encoder, avro_encoder]:
    print(f"Testing with: {is_sync=} , {encoder_fn=}")
    async with mock_producer_send_env() as (send_mock, producer):
        test_func = producer_decorator(
            {topic: (None, producer, None, None)},
            func_sync if is_sync else func_async,
            topic,
            encoder_fn=encoder_fn,
        )

        assert iscoroutinefunction(test_func) != is_sync

        value = test_func(mock_msg) if is_sync else await test_func(mock_msg)

        send_mock.assert_called_once_with(remove_suffix(topic), encoder_fn(mock_msg), key=None)

        assert value == mock_msg

Testing with: is_sync=False , encoder_fn=<function json_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
Testing with: is_sync=False , encoder_fn=<function avro_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called


In [None]:
test_key = b"key"


async def func_async(mock_msg: MockMsg) -> KafkaEvent[MockMsg]:
    return KafkaEvent(mock_msg, test_key)


def func_sync(mock_msg: MockMsg) -> KafkaEvent[MockMsg]:
    return KafkaEvent(mock_msg, test_key)


is_sync = False
for encoder_fn in [json_encoder, avro_encoder]:
    print(f"Testing with: {is_sync=} , {encoder_fn=}")
    async with mock_producer_send_env() as (send_mock, producer):
        test_func = producer_decorator(
            {topic: (None, producer, None, None)},
            func_sync if is_sync else func_async,
            topic,
            encoder_fn=encoder_fn,
        )

        assert iscoroutinefunction(test_func) != is_sync

        value = test_func(mock_msg) if is_sync else await test_func(mock_msg)

        send_mock.assert_called_once_with(remove_suffix(topic), encoder_fn(mock_msg), key=test_key)

        assert value == KafkaEvent(mock_msg, test_key)

Testing with: is_sync=False , encoder_fn=<function json_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
Testing with: is_sync=False , encoder_fn=<function avro_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called


In [None]:
batch_size = 123


async def func_async(mock_msg: MockMsg) -> List[MockMsg]:
    return [mock_msg] * batch_size


def func_sync(mock_msg: MockMsg) -> List[MockMsg]:
    return [mock_msg] * batch_size


is_sync = False
for encoder_fn in [json_encoder, avro_encoder]:
    print(f"Testing with: {is_sync=} , {encoder_fn=}")
    async with mock_producer_batch_env() as (
        batch_mock,
        send_batch_mock,
        producer,
    ):
        test_func = producer_decorator(
            {topic: (None, producer, None, None)},
            func_sync if is_sync else func_async,
            topic,
            encoder_fn=encoder_fn,
        )

        assert iscoroutinefunction(test_func) != is_sync

        value = test_func(mock_msg) if is_sync else await test_func(mock_msg)

        batch_mock.append.assert_has_calls(
            [call(key=None, value=encoder_fn(mock_msg), timestamp=ANY)] * batch_size
        )
        send_batch_mock.assert_called_once_with(batch_mock, remove_suffix(topic), partition=0)

        assert value == [mock_msg] * batch_size

Testing with: is_sync=False , encoder_fn=<function json_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
Testing with: is_sync=False , encoder_fn=<function avro_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called


In [None]:
batch_size = 123
test_key = b"key"


async def func_async(mock_msg: MockMsg) -> KafkaEvent[List[MockMsg]]:
    return KafkaEvent([mock_msg] * batch_size, test_key)


def func_sync(mock_msg: MockMsg) -> KafkaEvent[List[MockMsg]]:
    return KafkaEvent([mock_msg] * batch_size, test_key)


is_sync = False
for encoder_fn in [json_encoder, avro_encoder]:
    print(f"Testing with: {is_sync=} , {encoder_fn=}")
    async with mock_producer_batch_env() as (batch_mock, send_batch_mock, producer):
        test_func = producer_decorator(
            {topic: (None, producer, None, None)},
            func_sync if is_sync else func_async,
            topic,
            encoder_fn=encoder_fn,
        )

        assert iscoroutinefunction(test_func) != is_sync

        value = test_func(mock_msg) if is_sync else await test_func(mock_msg)

        batch_mock.append.assert_has_calls(
            [call(key=test_key, value=encoder_fn(mock_msg), timestamp=ANY)] * batch_size
        )

        send_batch_mock.assert_called_once_with(batch_mock, remove_suffix(topic), partition=0)

        assert value == KafkaEvent([mock_msg] * batch_size, test_key)

Testing with: is_sync=False , encoder_fn=<function json_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
Testing with: is_sync=False , encoder_fn=<function avro_encoder>
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker._patch_consumers_and_producers(): Patching consumers and producers!
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker starting
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched start() called()
[INFO] fastkafka._testing.in_memory_broker: InMemoryBroker stopping
[INFO] fastkafka._testing.in_memory_broker: AIOKafkaProducer patched stop() called
