Skip to content

Commit

Permalink
✨ Add TestClient class (#28)
Browse files Browse the repository at this point in the history
* ✨ Add `TestClient` class

* ✨ Add `DependencyOverrideManager` class

* ✨ Add default values for `publish` function

* 🐛 Fix `serialized_key_size` when `key` is `None`
  • Loading branch information
gabrielmbmb committed Apr 17, 2023
1 parent 9e12c94 commit ef9f8a7
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 10 deletions.
19 changes: 11 additions & 8 deletions kaflow/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class TopicConsumerFunc:
"publish_fn",
"exception_handlers",
"deserialization_error_handler",
"dependent",
"func",
"value_param_type",
"value_deserializer",
Expand All @@ -54,6 +53,7 @@ class TopicConsumerFunc:
"sink_topics",
"executor",
"container_state",
"dependent",
)

def __init__(
Expand Down Expand Up @@ -90,9 +90,7 @@ def __init__(
self.publish_fn = publish_fn
self.exception_handlers = exception_handlers
self.deserialization_error_handler = deserialization_error_handler
self.dependent = self.container.solve(
Dependent(func, scope="consumer"), scopes=Scopes
)
self.func = func
self.value_param_type = value_param_type
self.value_deserializer = value_deserializer
self.key_param_type = key_param_type
Expand All @@ -103,6 +101,9 @@ def __init__(

def prepare(self, state: ScopeState) -> None:
self.container_state = state
self.dependent = self.container.solve(
Dependent(self.func, scope="consumer"), scopes=Scopes
)

def _deserialize_value(self, value: bytes) -> TopicValueKeyHeader:
return _deserialize(value, self.value_param_type, self.value_deserializer)
Expand Down Expand Up @@ -230,7 +231,7 @@ async def _publish_messages(self, message: Message) -> None:
]
)

async def _process(self, read_message: ReadMessage) -> None:
async def _process(self, read_message: ReadMessage) -> Message | None:
async with self.container.enter_scope(
"consumer", state=self.container_state
) as consumer_state:
Expand All @@ -239,11 +240,13 @@ async def _process(self, read_message: ReadMessage) -> None:
)
if message and isinstance(message, Message):
await self._publish_messages(message)
return message
return None

async def consume(self, record: ConsumerRecord) -> None:
async def consume(self, record: ConsumerRecord) -> Message | None:
value, key, headers, deserialized = await self._deserialize(record)
if not deserialized:
return
return None
message = ReadMessage(
value=value,
key=key,
Expand All @@ -252,4 +255,4 @@ async def consume(self, record: ConsumerRecord) -> None:
partition=record.partition,
timestamp=record.timestamp,
)
await self._process(read_message=message)
return await self._process(read_message=message)
72 changes: 72 additions & 0 deletions kaflow/_utils/overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Taken from: https://github.com/adriangb/xpresso/blob/main/xpresso/_utils/overrides.py
from __future__ import annotations

import contextlib
import inspect
import typing
from types import TracebackType
from typing import cast

from di import Container
from di.api.dependencies import DependentBase
from di.api.providers import DependencyProvider
from di.dependent import Dependent
from typing_extensions import get_args

from kaflow._utils.inspect import is_annotated_param


def get_type(param: inspect.Parameter) -> type:
if is_annotated_param(param):
type_ = next(iter(get_args(param.annotation)))
else:
type_ = param.annotation
return cast(type, type_)


class DependencyOverrideManager:
_stacks: typing.List[contextlib.ExitStack]

def __init__(self, container: Container) -> None:
self._container = container
self._stacks = []

def __setitem__(
self, target: DependencyProvider, replacement: DependencyProvider
) -> None:
def hook(
param: typing.Optional[inspect.Parameter],
dependent: DependentBase[typing.Any],
) -> typing.Optional[DependentBase[typing.Any]]:
if not isinstance(dependent, Dependent):
return None
scope = dependent.scope
dep = Dependent(
replacement,
scope=scope,
use_cache=dependent.use_cache,
wire=dependent.wire,
)
if param is not None and param.annotation is not param.empty:
type_ = get_type(param)
if type_ is target:
return dep
if dependent.call is not None and dependent.call is target:
return dep
return None

cm = self._container.bind(hook)
if self._stacks:
self._stacks[-1].enter_context(cm)

def __enter__(self) -> DependencyOverrideManager:
self._stacks.append(contextlib.ExitStack().__enter__())
return self

def __exit__(
self,
__exc_type: typing.Optional[typing.Type[BaseException]],
__exc_value: typing.Optional[BaseException],
__traceback: typing.Optional[TracebackType],
) -> typing.Optional[bool]:
return self._stacks.pop().__exit__(__exc_type, __exc_value, __traceback)
10 changes: 8 additions & 2 deletions kaflow/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from kaflow._consumer import TopicConsumerFunc
from kaflow._utils.asyncio import asyncify
from kaflow._utils.inspect import is_not_coroutine_function
from kaflow._utils.overrides import DependencyOverrideManager
from kaflow.dependencies import Scopes
from kaflow.exceptions import KaflowDeserializationException
from kaflow.message import Message
Expand Down Expand Up @@ -197,6 +198,7 @@ def __init__(
# di
self._container = Container()
self._container_state = ScopeState()
self.dependency_overrides = DependencyOverrideManager(self._container)

self._loop = asyncio.get_event_loop()
self._consumer: AIOKafkaConsumer | None = None
Expand Down Expand Up @@ -260,7 +262,7 @@ def _add_topic_consumer_func(
topic_processor = TopicConsumerFunc(
name=topic,
container=self._container,
publish_fn=self._publish,
publish_fn=lambda *args, **kwargs: self._publish(*args, **kwargs),
exception_handlers=self._exception_handlers,
deserialization_error_handler=self._deserialization_error_handler,
func=func,
Expand Down Expand Up @@ -487,6 +489,9 @@ async def _publish(
headers=headers_,
)

def _get_consumer(self, topic: str) -> TopicConsumerFunc:
return self._consumers[topic]

async def _consuming_loop(self) -> None:
if not self._consumer:
raise RuntimeError(
Expand All @@ -495,7 +500,8 @@ async def _consuming_loop(self) -> None:
" called yet."
)
async for record in self._consumer:
await self._consumers[record.topic].consume(record=record)
consumer = self._get_consumer(record.topic)
await consumer.consume(record=record)

async def start(self) -> None:
self._consumer = self._create_consumer()
Expand Down
64 changes: 64 additions & 0 deletions kaflow/testclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import asyncio
from functools import wraps
from time import time
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from aiokafka import ConsumerRecord

if TYPE_CHECKING:
from kaflow.applications import Kaflow
from kaflow.message import Message


def intercept_publish(
func: Callable[..., Awaitable[None]]
) -> Callable[..., Awaitable[None]]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> None:
pass

return wrapper


class TestClient:
"""Test client for testing a `Kaflow` application."""

def __init__(self, app: Kaflow) -> None:
self.app = app
self.app._publish = intercept_publish(self.app._publish) # type: ignore
self._loop = asyncio.get_event_loop()

def publish(
self,
topic: str,
value: bytes,
key: bytes | None = None,
headers: dict[str, bytes] | None = None,
partition: int = 0,
offset: int = 0,
timestamp: int | None = None,
) -> Message | None:
if timestamp is None:
timestamp = int(time())
record = ConsumerRecord(
topic=topic,
partition=partition,
offset=offset,
timestamp=timestamp,
timestamp_type=0,
key=key,
value=value,
checksum=0,
serialized_key_size=len(key) if key else 0,
serialized_value_size=len(value),
headers=headers,
)

async def _publish() -> Message | None:
consumer = self.app._get_consumer(topic)
async with self.app.lifespan():
return await consumer.consume(record)

return self._loop.run_until_complete(_publish())

0 comments on commit ef9f8a7

Please sign in to comment.