diff --git a/docs/core-concepts/auto-retry.md b/docs/core-concepts/auto-retry.md new file mode 100644 index 0000000..6f91b79 --- /dev/null +++ b/docs/core-concepts/auto-retry.md @@ -0,0 +1,32 @@ +# Auto retry + +## What is auto retry? + +Auto retry is a feature that allows you to automatically retry a failed HTTP request. This is useful when you want to retry a request that failed due to a network error or a timeout. + +## How does it work? + +When you make a request, the request is sent to the server. If the server responds with an error, the request is retried. If the server responds with a success, the request is not retried. + +## How do I use it? + +To use auto retry, you need to decorate your endpoint method with `@retry`. + +It takes four arguments: + +- `max_retries`: The maximum number of times to retry the request. +- `backoff_factor`: The backoff factor to use when retrying the request. +- `exceptions`: The list of status codes to retry. +- `delay`: The delay between retries. + + +```python +from declarativex import retry, http, TimeoutException + + +@retry(max_retries=3, backoff_factor=0.5, exceptions=(TimeoutException,), delay=0.5) +@http("GET", "/status/500", base_url="https://httpbin.org", timeout=0.1) +def get(): + pass +``` + diff --git a/src/declarativex/__init__.py b/src/declarativex/__init__.py index 7194bae..378cac0 100644 --- a/src/declarativex/__init__.py +++ b/src/declarativex/__init__.py @@ -22,5 +22,6 @@ from .methods import http from .middlewares import Middleware from .rate_limiter import rate_limiter +from .retry import retry __version__ = "v1.0.0" diff --git a/src/declarativex/methods.py b/src/declarativex/methods.py index 91d3fa6..61fd673 100644 --- a/src/declarativex/methods.py +++ b/src/declarativex/methods.py @@ -13,14 +13,14 @@ from .executors import AsyncExecutor, SyncExecutor from .middlewares import Middleware from .models import ClientConfiguration, EndpointConfiguration -from .utils import ReturnType +from .utils import ReturnType, DECLARED_MARK def http( method: str, path: str, *, - timeout: Optional[int] = None, + timeout: Optional[float] = None, base_url: str = "", default_query_params: Optional[Dict[str, Any]] = None, default_headers: Optional[Dict[str, str]] = None, @@ -59,7 +59,7 @@ def inner(*args: Any, **kwargs: Any): endpoint_configuration=endpoint_configuration ).execute(func, *args, **kwargs) - setattr(inner, "_declarativex", True) + setattr(inner, DECLARED_MARK, True) inner.__annotations__["return"] = ( inspect.signature(func).return_annotation ) diff --git a/src/declarativex/middlewares.py b/src/declarativex/middlewares.py index 4f5bcdb..3dabfe5 100644 --- a/src/declarativex/middlewares.py +++ b/src/declarativex/middlewares.py @@ -4,14 +4,13 @@ from typing import ( Callable, TYPE_CHECKING, - TypeVar, ) +from .utils import ReturnType + if TYPE_CHECKING: from .models import RawRequest -ReturnType = TypeVar("ReturnType") - class Signature(abc.ABCMeta): expected_signature = { diff --git a/src/declarativex/rate_limiter.py b/src/declarativex/rate_limiter.py index aa0ce32..97f7c4c 100644 --- a/src/declarativex/rate_limiter.py +++ b/src/declarativex/rate_limiter.py @@ -1,11 +1,9 @@ import asyncio import time -from functools import wraps -from typing import TypeVar, Callable, Union, Awaitable +from typing import Callable, Union, Awaitable from .exceptions import RateLimitExceeded - -ReturnType = TypeVar("ReturnType") +from .utils import ReturnType, DeclaredDecorator class Bucket: @@ -30,14 +28,14 @@ def refill(self): self.last_time_token_added = 0.0 -class rate_limiter: +class rate_limiter(DeclaredDecorator): def __init__(self, max_calls: int, interval: float, reject: bool = False): self._bucket = Bucket(max_calls, interval) self._reject = reject self._loop = asyncio.get_event_loop() self._lock = asyncio.Lock() - async def decorate_async( + async def _decorate_async( self, func: Callable[..., Awaitable[ReturnType]], *args, **kwargs ) -> ReturnType: async with self._lock: @@ -71,7 +69,7 @@ async def decorate_async( # if the bucket is empty, we have to wait self._bucket.token_bucket -= 1.0 - def decorate_sync( + def _decorate_sync( self, func: Callable[..., ReturnType], *args, **kwargs ) -> ReturnType: try: @@ -99,32 +97,17 @@ def decorate_sync( # if the bucket is empty, we have to wait self._bucket.token_bucket -= 1.0 - def decorate_class(self, cls: type) -> type: - for attr_name, attr_value in cls.__dict__.items(): - if hasattr(attr_value, "_declarativex"): - setattr(cls, attr_name, self(attr_value)) - setattr(cls, "_rate_limiter_bucket", self._bucket) + def refill(self): + self._bucket.refill() + + def _decorate_class(self, cls: type) -> type: + cls = super()._decorate_class(cls) + setattr(cls, "refill", self.refill) return cls def __call__( self, func_or_class: Union[Callable[..., ReturnType], type] ) -> Union[Callable[..., ReturnType], type]: - if isinstance(func_or_class, type): - return self.decorate_class(func_or_class) - - if asyncio.iscoroutinefunction(func_or_class): - - @wraps(func_or_class) - async def inner(*args, **kwargs): - return await self.decorate_async( - func_or_class, *args, **kwargs - ) - - else: - - @wraps(func_or_class) - def inner(*args, **kwargs): - return self.decorate_sync(func_or_class, *args, **kwargs) - - setattr(inner, "_rate_limiter_bucket", self._bucket) + inner = super().__call__(func_or_class) + setattr(inner, "refill", self.refill) return inner diff --git a/src/declarativex/retry.py b/src/declarativex/retry.py new file mode 100644 index 0000000..f33ee8e --- /dev/null +++ b/src/declarativex/retry.py @@ -0,0 +1,51 @@ +import asyncio +import time +from typing import Callable + +from .utils import DeclaredDecorator + + +class retry(DeclaredDecorator): + def __init__( + self, + max_retries: int, + exceptions: tuple, + delay: float = 0.0, + backoff_factor: float = 1.0, + ): + self._max_retries = max_retries + self._delay = delay + self._backoff_factor = backoff_factor + self._exceptions = exceptions + + async def _decorate_async( + self, func: Callable, *args, **kwargs + ): + retries = 0 + current_delay = self._delay + while retries <= self._max_retries: + try: + return await func(*args, **kwargs) + except self._exceptions as e: + retries += 1 + if retries > self._max_retries: + raise e + await asyncio.sleep(current_delay) + current_delay *= self._backoff_factor + return None # pragma: no cover + + def _decorate_sync( + self, func: Callable, *args, **kwargs + ): + retries = 0 + current_delay = self._delay + while retries <= self._max_retries: + try: + return func(*args, **kwargs) + except self._exceptions as e: + retries += 1 + if retries > self._max_retries: + raise e + time.sleep(current_delay) + current_delay *= self._backoff_factor + return None # pragma: no cover diff --git a/src/declarativex/utils.py b/src/declarativex/utils.py index 1bdaa26..c9feb32 100644 --- a/src/declarativex/utils.py +++ b/src/declarativex/utils.py @@ -1,5 +1,65 @@ -from typing import TypeVar +import abc +import asyncio +from functools import wraps +from typing import TypeVar, Callable, Union +from .exceptions import MisconfiguredException ReturnType = TypeVar("ReturnType") SUPPORTED_METHODS = {"GET", "POST", "PUT", "PATCH", "DELETE"} +DECLARED_MARK = "_declarativex_declared" + + +class DeclaredDecorator(abc.ABC): + @abc.abstractmethod + async def _decorate_async( + self, func: Callable, *args, **kwargs + ): + raise NotImplementedError + + @abc.abstractmethod + def _decorate_sync( + self, func: Callable, *args, **kwargs + ): + raise NotImplementedError + + def _decorate_class(self, cls: type) -> type: + for attr_name, attr_value in cls.__dict__.items(): + if hasattr(attr_value, DECLARED_MARK): + setattr(cls, attr_name, self(attr_value)) + setattr(cls, DECLARED_MARK, True) + return cls + + @property + def mark(self) -> str: + return f"_{self.__class__.__name__}" + + def __call__( + self, func_or_cls: Union[Callable[..., ReturnType], type] + ) -> Union[Callable[..., ReturnType], type]: + if hasattr(func_or_cls, self.mark): + raise MisconfiguredException( + f"Cannot decorate function with " + f"@{self.__class__.__name__} twice" + ) + + if isinstance(func_or_cls, type): + return self._decorate_class(func_or_cls) + + if asyncio.iscoroutinefunction(func_or_cls): + + @wraps(func_or_cls) + async def inner(*args, **kwargs): + return await self._decorate_async(func_or_cls, *args, **kwargs) + + else: + + @wraps(func_or_cls) + def inner(*args, **kwargs): + return self._decorate_sync(func_or_cls, *args, **kwargs) + + setattr( + inner, DECLARED_MARK, getattr(func_or_cls, DECLARED_MARK, False) + ) + setattr(inner, self.mark, True) + return inner diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py index 947b710..100500d 100644 --- a/tests/test_rate_limiter.py +++ b/tests/test_rate_limiter.py @@ -4,9 +4,16 @@ import pytest -from declarativex import rate_limiter, BaseClient, http, RateLimitExceeded +from declarativex import ( + rate_limiter, + BaseClient, + http, + RateLimitExceeded, + MisconfiguredException, +) +@rate_limiter(max_calls=1, interval=1, reject=False) class DummyClient(BaseClient): base_url = "https://reqres.in/" @@ -19,6 +26,7 @@ async def get_user(self, user_id: int) -> dict: ... +@rate_limiter(max_calls=1, interval=1, reject=False) class SyncDummyClient(BaseClient): base_url = "https://reqres.in/" @@ -31,24 +39,35 @@ def get_user(self, user_id: int) -> dict: ... -def create_client(*, client, max_calls: int, reject: bool): - client = rate_limiter(max_calls=max_calls, interval=1, reject=reject)(client)() - - assert hasattr(client, "_rate_limiter_bucket") - assert client._rate_limiter_bucket._max_calls == max_calls - assert client._rate_limiter_bucket._interval == 1 - - assert hasattr(client.get_users, "_rate_limiter_bucket") - assert client.get_users._rate_limiter_bucket._max_calls == max_calls - assert client.get_users._rate_limiter_bucket._interval == 1 - assert client.get_users._rate_limiter_bucket == client._rate_limiter_bucket +@rate_limiter(max_calls=0, interval=1, reject=True) +class RejectDummyClient(BaseClient): + base_url = "https://reqres.in/" + + @http("GET", "/api/users") + async def get_users(self) -> dict: + ... + + @http("GET", "/api/users/{user_id}") + async def get_user(self, user_id: int) -> dict: + ... - return client + +@rate_limiter(max_calls=0, interval=1, reject=True) +class RejectSyncDummyClient(BaseClient): + base_url = "https://reqres.in/" + + @http("GET", "/api/users") + def get_users(self) -> dict: + ... + + @http("GET", "/api/users/{user_id}") + def get_user(self, user_id: int) -> dict: + ... @pytest.mark.asyncio async def test_rate_limiter(): - client = create_client(client=DummyClient, max_calls=1, reject=False) + client = DummyClient() start = time.perf_counter() for i in range(2): await client.get_users() @@ -56,7 +75,7 @@ async def test_rate_limiter(): total = time.perf_counter() - start assert 3.0 < total < 3.5 - client._rate_limiter_bucket.refill() + client.refill() if sys.version_info >= (3, 10): start = time.perf_counter() @@ -68,14 +87,14 @@ async def test_rate_limiter(): @pytest.mark.asyncio async def test_rate_limiter_rejection(): - client = create_client(client=DummyClient, max_calls=0, reject=True) - client._rate_limiter_bucket.refill() + client = RejectDummyClient() + client.refill() with pytest.raises(RateLimitExceeded): await client.get_users() def test_rate_limiter_sync(): - client = create_client(client=SyncDummyClient, max_calls=1, reject=False) + client = SyncDummyClient() start = time.perf_counter() for i in range(2): client.get_users() @@ -83,7 +102,7 @@ def test_rate_limiter_sync(): total = time.perf_counter() - start assert 3.0 < total < 3.5 - client._rate_limiter_bucket.refill() + client.refill() start = time.perf_counter() for i in range(3): @@ -93,7 +112,22 @@ def test_rate_limiter_sync(): def test_rate_limiter_sync_rejection(): - client = create_client(client=SyncDummyClient, max_calls=0, reject=True) - client._rate_limiter_bucket.refill() + client = RejectSyncDummyClient() + client.refill() with pytest.raises(RateLimitExceeded): client.get_users() + + +def test_double_decoration(): + with pytest.raises(MisconfiguredException) as exc: + + @rate_limiter(max_calls=1, interval=1, reject=False) + class FooClient(BaseClient): + @rate_limiter(max_calls=1, interval=1, reject=False) + @http("GET", "/api/users") + async def get_users(self) -> dict: + ... + + assert ( + str(exc.value) == "Cannot decorate function with @rate_limiter twice" + ) diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 0000000..708e7a6 --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,78 @@ +import asyncio +import time +import warnings +from typing import Annotated +from unittest.mock import MagicMock + +import httpx +import pytest +from pytest_mock import MockerFixture + +from declarativex import BaseClient, http, retry, TimeoutException, Query + + +@retry(max_retries=3, delay=0.1, exceptions=(TimeoutException,)) +class DummyClient(BaseClient): + base_url = "https://reqres.in/" + + @http("GET", "/api/users", timeout=0.1) + async def get_users(self, delay: Annotated[int, Query] = 5) -> dict: + ... + + +@retry(max_retries=3, delay=0.1, exceptions=(TimeoutException,)) +class SyncDummyClient(BaseClient): + base_url = "https://reqres.in/" + + @http("GET", "/api/users", timeout=0.1) + def get_users(self, delay: Annotated[int, Query] = 5) -> dict: + ... + + +client = DummyClient() +sync_client = SyncDummyClient() + + +@pytest.mark.asyncio +async def test_retry(mocker: MockerFixture): + call = mocker.patch( + "declarativex.executors.httpx.AsyncClient.send", + side_effect=TimeoutException( + 0.1, httpx.Request("GET", "https://reqres.in/api/users") + ), + ) + sleep = mocker.patch("asyncio.sleep", MagicMock(wraps=asyncio.sleep)) + try: + await client.get_users() + except TimeoutException: + pass + else: + raise AssertionError("Expected TimeoutException") + + assert call.call_count == 4 + assert 3 == len( + [call[0][0] for call in sleep.call_args_list if call[0][0] == 0.1] + ) + + +def test_sync_retry(mocker: MockerFixture): + call = mocker.patch( + "declarativex.executors.httpx.Client.send", + side_effect=TimeoutException( + 0.1, httpx.Request("GET", "https://reqres.in/api/users") + ), + ) + sleep = mocker.patch("time.sleep", MagicMock(wraps=time.sleep)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + sync_client.get_users() + except TimeoutException: + pass + else: + raise AssertionError("Expected TimeoutException") + + assert call.call_count == 4 + assert 3 == len( + [call[0][0] for call in sleep.call_args_list if call[0][0] == 0.1] + )