diff --git a/kaflow/_consumer.py b/kaflow/_consumer.py index bc7cd44..797a7ca 100644 --- a/kaflow/_consumer.py +++ b/kaflow/_consumer.py @@ -44,7 +44,6 @@ class TopicConsumerFunc: "publish_fn", "exception_handlers", "deserialization_error_handler", - "dependent", "func", "value_param_type", "value_deserializer", @@ -54,6 +53,7 @@ class TopicConsumerFunc: "sink_topics", "executor", "container_state", + "dependent", ) def __init__( @@ -90,9 +90,7 @@ def __init__( self.publish_fn = publish_fn self.exception_handlers = exception_handlers self.deserialization_error_handler = deserialization_error_handler - self.dependent = self.container.solve( - Dependent(func, scope="consumer"), scopes=Scopes - ) + self.func = func self.value_param_type = value_param_type self.value_deserializer = value_deserializer self.key_param_type = key_param_type @@ -103,6 +101,9 @@ def __init__( def prepare(self, state: ScopeState) -> None: self.container_state = state + self.dependent = self.container.solve( + Dependent(self.func, scope="consumer"), scopes=Scopes + ) def _deserialize_value(self, value: bytes) -> TopicValueKeyHeader: return _deserialize(value, self.value_param_type, self.value_deserializer) @@ -230,7 +231,7 @@ async def _publish_messages(self, message: Message) -> None: ] ) - async def _process(self, read_message: ReadMessage) -> None: + async def _process(self, read_message: ReadMessage) -> Message | None: async with self.container.enter_scope( "consumer", state=self.container_state ) as consumer_state: @@ -239,11 +240,13 @@ async def _process(self, read_message: ReadMessage) -> None: ) if message and isinstance(message, Message): await self._publish_messages(message) + return message + return None - async def consume(self, record: ConsumerRecord) -> None: + async def consume(self, record: ConsumerRecord) -> Message | None: value, key, headers, deserialized = await self._deserialize(record) if not deserialized: - return + return None message = ReadMessage( value=value, key=key, @@ -252,4 +255,4 @@ async def consume(self, record: ConsumerRecord) -> None: partition=record.partition, timestamp=record.timestamp, ) - await self._process(read_message=message) + return await self._process(read_message=message) diff --git a/kaflow/_utils/overrides.py b/kaflow/_utils/overrides.py new file mode 100644 index 0000000..81ab8b8 --- /dev/null +++ b/kaflow/_utils/overrides.py @@ -0,0 +1,72 @@ +# Taken from: https://github.com/adriangb/xpresso/blob/main/xpresso/_utils/overrides.py +from __future__ import annotations + +import contextlib +import inspect +import typing +from types import TracebackType +from typing import cast + +from di import Container +from di.api.dependencies import DependentBase +from di.api.providers import DependencyProvider +from di.dependent import Dependent +from typing_extensions import get_args + +from kaflow._utils.inspect import is_annotated_param + + +def get_type(param: inspect.Parameter) -> type: + if is_annotated_param(param): + type_ = next(iter(get_args(param.annotation))) + else: + type_ = param.annotation + return cast(type, type_) + + +class DependencyOverrideManager: + _stacks: typing.List[contextlib.ExitStack] + + def __init__(self, container: Container) -> None: + self._container = container + self._stacks = [] + + def __setitem__( + self, target: DependencyProvider, replacement: DependencyProvider + ) -> None: + def hook( + param: typing.Optional[inspect.Parameter], + dependent: DependentBase[typing.Any], + ) -> typing.Optional[DependentBase[typing.Any]]: + if not isinstance(dependent, Dependent): + return None + scope = dependent.scope + dep = Dependent( + replacement, + scope=scope, + use_cache=dependent.use_cache, + wire=dependent.wire, + ) + if param is not None and param.annotation is not param.empty: + type_ = get_type(param) + if type_ is target: + return dep + if dependent.call is not None and dependent.call is target: + return dep + return None + + cm = self._container.bind(hook) + if self._stacks: + self._stacks[-1].enter_context(cm) + + def __enter__(self) -> DependencyOverrideManager: + self._stacks.append(contextlib.ExitStack().__enter__()) + return self + + def __exit__( + self, + __exc_type: typing.Optional[typing.Type[BaseException]], + __exc_value: typing.Optional[BaseException], + __traceback: typing.Optional[TracebackType], + ) -> typing.Optional[bool]: + return self._stacks.pop().__exit__(__exc_type, __exc_value, __traceback) diff --git a/kaflow/applications.py b/kaflow/applications.py index 7f0ddaf..6675b8d 100644 --- a/kaflow/applications.py +++ b/kaflow/applications.py @@ -31,6 +31,7 @@ from kaflow._consumer import TopicConsumerFunc from kaflow._utils.asyncio import asyncify from kaflow._utils.inspect import is_not_coroutine_function +from kaflow._utils.overrides import DependencyOverrideManager from kaflow.dependencies import Scopes from kaflow.exceptions import KaflowDeserializationException from kaflow.message import Message @@ -197,6 +198,7 @@ def __init__( # di self._container = Container() self._container_state = ScopeState() + self.dependency_overrides = DependencyOverrideManager(self._container) self._loop = asyncio.get_event_loop() self._consumer: AIOKafkaConsumer | None = None @@ -260,7 +262,7 @@ def _add_topic_consumer_func( topic_processor = TopicConsumerFunc( name=topic, container=self._container, - publish_fn=self._publish, + publish_fn=lambda *args, **kwargs: self._publish(*args, **kwargs), exception_handlers=self._exception_handlers, deserialization_error_handler=self._deserialization_error_handler, func=func, @@ -487,6 +489,9 @@ async def _publish( headers=headers_, ) + def _get_consumer(self, topic: str) -> TopicConsumerFunc: + return self._consumers[topic] + async def _consuming_loop(self) -> None: if not self._consumer: raise RuntimeError( @@ -495,7 +500,8 @@ async def _consuming_loop(self) -> None: " called yet." ) async for record in self._consumer: - await self._consumers[record.topic].consume(record=record) + consumer = self._get_consumer(record.topic) + await consumer.consume(record=record) async def start(self) -> None: self._consumer = self._create_consumer() diff --git a/kaflow/testclient.py b/kaflow/testclient.py new file mode 100644 index 0000000..9494296 --- /dev/null +++ b/kaflow/testclient.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import asyncio +from functools import wraps +from time import time +from typing import TYPE_CHECKING, Any, Awaitable, Callable + +from aiokafka import ConsumerRecord + +if TYPE_CHECKING: + from kaflow.applications import Kaflow + from kaflow.message import Message + + +def intercept_publish( + func: Callable[..., Awaitable[None]] +) -> Callable[..., Awaitable[None]]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> None: + pass + + return wrapper + + +class TestClient: + """Test client for testing a `Kaflow` application.""" + + def __init__(self, app: Kaflow) -> None: + self.app = app + self.app._publish = intercept_publish(self.app._publish) # type: ignore + self._loop = asyncio.get_event_loop() + + def publish( + self, + topic: str, + value: bytes, + key: bytes | None = None, + headers: dict[str, bytes] | None = None, + partition: int = 0, + offset: int = 0, + timestamp: int | None = None, + ) -> Message | None: + if timestamp is None: + timestamp = int(time()) + record = ConsumerRecord( + topic=topic, + partition=partition, + offset=offset, + timestamp=timestamp, + timestamp_type=0, + key=key, + value=value, + checksum=0, + serialized_key_size=len(key) if key else 0, + serialized_value_size=len(value), + headers=headers, + ) + + async def _publish() -> Message | None: + consumer = self.app._get_consumer(topic) + async with self.app.lifespan(): + return await consumer.consume(record) + + return self._loop.run_until_complete(_publish())