-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ✨ Add `TestClient` class * ✨ Add `DependencyOverrideManager` class * ✨ Add default values for `publish` function * 🐛 Fix `serialized_key_size` when `key` is `None`
- Loading branch information
1 parent
9e12c94
commit ef9f8a7
Showing
4 changed files
with
155 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Taken from: https://github.com/adriangb/xpresso/blob/main/xpresso/_utils/overrides.py | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
import inspect | ||
import typing | ||
from types import TracebackType | ||
from typing import cast | ||
|
||
from di import Container | ||
from di.api.dependencies import DependentBase | ||
from di.api.providers import DependencyProvider | ||
from di.dependent import Dependent | ||
from typing_extensions import get_args | ||
|
||
from kaflow._utils.inspect import is_annotated_param | ||
|
||
|
||
def get_type(param: inspect.Parameter) -> type: | ||
if is_annotated_param(param): | ||
type_ = next(iter(get_args(param.annotation))) | ||
else: | ||
type_ = param.annotation | ||
return cast(type, type_) | ||
|
||
|
||
class DependencyOverrideManager: | ||
_stacks: typing.List[contextlib.ExitStack] | ||
|
||
def __init__(self, container: Container) -> None: | ||
self._container = container | ||
self._stacks = [] | ||
|
||
def __setitem__( | ||
self, target: DependencyProvider, replacement: DependencyProvider | ||
) -> None: | ||
def hook( | ||
param: typing.Optional[inspect.Parameter], | ||
dependent: DependentBase[typing.Any], | ||
) -> typing.Optional[DependentBase[typing.Any]]: | ||
if not isinstance(dependent, Dependent): | ||
return None | ||
scope = dependent.scope | ||
dep = Dependent( | ||
replacement, | ||
scope=scope, | ||
use_cache=dependent.use_cache, | ||
wire=dependent.wire, | ||
) | ||
if param is not None and param.annotation is not param.empty: | ||
type_ = get_type(param) | ||
if type_ is target: | ||
return dep | ||
if dependent.call is not None and dependent.call is target: | ||
return dep | ||
return None | ||
|
||
cm = self._container.bind(hook) | ||
if self._stacks: | ||
self._stacks[-1].enter_context(cm) | ||
|
||
def __enter__(self) -> DependencyOverrideManager: | ||
self._stacks.append(contextlib.ExitStack().__enter__()) | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
__exc_type: typing.Optional[typing.Type[BaseException]], | ||
__exc_value: typing.Optional[BaseException], | ||
__traceback: typing.Optional[TracebackType], | ||
) -> typing.Optional[bool]: | ||
return self._stacks.pop().__exit__(__exc_type, __exc_value, __traceback) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from functools import wraps | ||
from time import time | ||
from typing import TYPE_CHECKING, Any, Awaitable, Callable | ||
|
||
from aiokafka import ConsumerRecord | ||
|
||
if TYPE_CHECKING: | ||
from kaflow.applications import Kaflow | ||
from kaflow.message import Message | ||
|
||
|
||
def intercept_publish( | ||
func: Callable[..., Awaitable[None]] | ||
) -> Callable[..., Awaitable[None]]: | ||
@wraps(func) | ||
async def wrapper(*args: Any, **kwargs: Any) -> None: | ||
pass | ||
|
||
return wrapper | ||
|
||
|
||
class TestClient: | ||
"""Test client for testing a `Kaflow` application.""" | ||
|
||
def __init__(self, app: Kaflow) -> None: | ||
self.app = app | ||
self.app._publish = intercept_publish(self.app._publish) # type: ignore | ||
self._loop = asyncio.get_event_loop() | ||
|
||
def publish( | ||
self, | ||
topic: str, | ||
value: bytes, | ||
key: bytes | None = None, | ||
headers: dict[str, bytes] | None = None, | ||
partition: int = 0, | ||
offset: int = 0, | ||
timestamp: int | None = None, | ||
) -> Message | None: | ||
if timestamp is None: | ||
timestamp = int(time()) | ||
record = ConsumerRecord( | ||
topic=topic, | ||
partition=partition, | ||
offset=offset, | ||
timestamp=timestamp, | ||
timestamp_type=0, | ||
key=key, | ||
value=value, | ||
checksum=0, | ||
serialized_key_size=len(key) if key else 0, | ||
serialized_value_size=len(value), | ||
headers=headers, | ||
) | ||
|
||
async def _publish() -> Message | None: | ||
consumer = self.app._get_consumer(topic) | ||
async with self.app.lifespan(): | ||
return await consumer.consume(record) | ||
|
||
return self._loop.run_until_complete(_publish()) |