diff --git a/docs/docs/guides/response/pagination.md b/docs/docs/guides/response/pagination.md index 02d7eb8e5..fa3619a80 100644 --- a/docs/docs/guides/response/pagination.md +++ b/docs/docs/guides/response/pagination.md @@ -148,6 +148,10 @@ def paginate_queryset(self, queryset, pagination: Input, **params): request = params["request"] ``` +#### Async Pagination + +Standard **Django Ninja** pagination classes support async. If you wish to handle async requests with a custom pagination class, you should subclass `ninja.pagination.AsyncPaginationBase` and override the `apaginate_queryset(self, queryset, request, **params)` method. + ### Output attribute By defult page items are placed to `'items'` attribute. To override this behaviour use `items_attribute`: diff --git a/ninja/pagination.py b/ninja/pagination.py index 379b91b58..5dd4387c2 100644 --- a/ninja/pagination.py +++ b/ninja/pagination.py @@ -1,7 +1,7 @@ import inspect from abc import ABC, abstractmethod from functools import partial, wraps -from typing import Any, Callable, List, Optional, Tuple, Type +from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple, Type, Union from django.db.models import QuerySet from django.http import HttpRequest @@ -14,7 +14,11 @@ from ninja.errors import ConfigError from ninja.operation import Operation from ninja.signature.details import is_collection_type -from ninja.utils import contribute_operation_args, contribute_operation_callback +from ninja.utils import ( + contribute_operation_args, + contribute_operation_callback, + is_async_callable, +) class PaginationBase(ABC): @@ -53,7 +57,24 @@ def _items_count(self, queryset: QuerySet) -> int: return len(queryset) -class LimitOffsetPagination(PaginationBase): +class AsyncPaginationBase(PaginationBase): + @abstractmethod + async def apaginate_queryset( + self, + queryset: QuerySet, + pagination: Any, + **params: Any, + ) -> Any: + pass # pragma: no cover + + async def _aitems_count(self, queryset: QuerySet) -> int: + try: + return await queryset.all().acount() + except AttributeError: + return len(queryset) + + +class LimitOffsetPagination(AsyncPaginationBase): class Input(Schema): limit: int = Field(settings.PAGINATION_PER_PAGE, ge=1) offset: int = Field(0, ge=0) @@ -71,8 +92,21 @@ def paginate_queryset( "count": self._items_count(queryset), } # noqa: E203 + async def apaginate_queryset( + self, + queryset: QuerySet, + pagination: Input, + **params: Any, + ) -> Any: + offset = pagination.offset + limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT) + return { + "items": queryset[offset : offset + limit], + "count": await self._aitems_count(queryset), + } # noqa: E203 -class PageNumberPagination(PaginationBase): + +class PageNumberPagination(AsyncPaginationBase): class Input(Schema): page: int = Field(1, ge=1) @@ -94,6 +128,18 @@ def paginate_queryset( "count": self._items_count(queryset), } # noqa: E203 + async def apaginate_queryset( + self, + queryset: QuerySet, + pagination: Input, + **params: Any, + ) -> Any: + offset = (pagination.page - 1) * self.page_size + return { + "items": queryset[offset : offset + self.page_size], + "count": await self._aitems_count(queryset), + } # noqa: E203 + def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable: """ @@ -112,7 +158,9 @@ def my_view(request): isfunction = inspect.isfunction(func_or_pgn_class) isnotset = func_or_pgn_class == NOT_SET - pagination_class: Type[PaginationBase] = import_string(settings.PAGINATION_CLASS) + pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string( + settings.PAGINATION_CLASS + ) if isfunction: return _inject_pagination(func_or_pgn_class, pagination_class) @@ -128,26 +176,55 @@ def wrapper(func: Callable) -> Any: def _inject_pagination( func: Callable, - paginator_class: Type[PaginationBase], + paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]], **paginator_params: Any, ) -> Callable: - paginator: PaginationBase = paginator_class(**paginator_params) - - @wraps(func) - def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: - pagination_params = kwargs.pop("ninja_pagination") - if paginator.pass_parameter: - kwargs[paginator.pass_parameter] = pagination_params - - items = func(request, **kwargs) - - result = paginator.paginate_queryset( - items, pagination=pagination_params, request=request, **kwargs - ) - if paginator.Output: # type: ignore - result[paginator.items_attribute] = list(result[paginator.items_attribute]) - # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here - return result + paginator = paginator_class(**paginator_params) + if is_async_callable(func): + if not hasattr(paginator, "apaginate_queryset"): + raise ConfigError("Pagination class not configured for async requests") + + @wraps(func) + async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: + pagination_params = kwargs.pop("ninja_pagination") + if paginator.pass_parameter: + kwargs[paginator.pass_parameter] = pagination_params + + items = await func(request, **kwargs) + + result = await paginator.apaginate_queryset( + items, pagination=pagination_params, request=request, **kwargs + ) + + async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator: + for result in results: + yield result + + if paginator.Output: # type: ignore + result[paginator.items_attribute] = [ + result + async for result in evaluate(result[paginator.items_attribute]) + ] + return result + else: + + @wraps(func) + def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: + pagination_params = kwargs.pop("ninja_pagination") + if paginator.pass_parameter: + kwargs[paginator.pass_parameter] = pagination_params + + items = func(request, **kwargs) + + result = paginator.paginate_queryset( + items, pagination=pagination_params, request=request, **kwargs + ) + if paginator.Output: # type: ignore + result[paginator.items_attribute] = list( + result[paginator.items_attribute] + ) + # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here + return result contribute_operation_args( view_with_pagination, diff --git a/tests/test_pagination_async.py b/tests/test_pagination_async.py new file mode 100644 index 000000000..15767cd21 --- /dev/null +++ b/tests/test_pagination_async.py @@ -0,0 +1,123 @@ +import asyncio +from typing import Any, List + +import pytest +from django.db.models import QuerySet + +from ninja import NinjaAPI, Schema +from ninja.errors import ConfigError +from ninja.pagination import ( + AsyncPaginationBase, + PageNumberPagination, + PaginationBase, + paginate, +) +from ninja.testing import TestAsyncClient + +api = NinjaAPI() + +ITEMS = list(range(100)) + + +class NoAsyncPagination(PaginationBase): + # only offset param, defaults to 5 per page + class Input(Schema): + skip: int + + class Output(Schema): + items: List[Any] + count: str + skip: int + + def paginate_queryset(self, items, pagination: Input, **params): + skip = pagination.skip + return { + "items": items[skip : skip + 5], + "count": "many", + "skip": skip, + } + + +class AsyncNoOutputPagination(AsyncPaginationBase): + # Outputs items without count attribute + class Input(Schema): + skip: int + + Output = None + + def paginate_queryset(self, items, pagination: Input, **params): + skip = pagination.skip + return items[skip : skip + 5] + + async def apaginate_queryset(self, items, pagination: Input, **params): + await asyncio.sleep(0) + skip = pagination.skip + return items[skip : skip + 5] + + def _items_count(self, queryset: QuerySet) -> int: + try: + # forcing to find queryset.count instead of list.count: + return queryset.all().count() + except AttributeError: + asyncio.sleep(0) + return len(queryset) + + +@pytest.mark.asyncio +async def test_async_config_error(): + api = NinjaAPI() + + with pytest.raises( + ConfigError, match="Pagination class not configured for async requests" + ): + + @api.get("/items_async_undefined", response=List[int]) + @paginate(NoAsyncPagination) + async def items_async_undefined(request, **kwargs): + return ITEMS + + +@pytest.mark.asyncio +async def test_async_custom_pagination(): + api = NinjaAPI() + + @api.get("/items_async", response=List[int]) + @paginate(AsyncNoOutputPagination) + async def items_async(request): + return ITEMS + + client = TestAsyncClient(api) + + response = await client.get("/items_async?skip=10") + assert response.json() == [10, 11, 12, 13, 14] + + +@pytest.mark.asyncio +async def test_async_default(): + api = NinjaAPI() + + @api.get("/items_default", response=List[int]) + @paginate # WITHOUT brackets (should use default pagination) + async def items_default(request, someparam: int = 0, **kwargs): + asyncio.sleep(0) + return ITEMS + + client = TestAsyncClient(api) + + response = await client.get("/items_default?limit=10") + assert response.json() == {"items": ITEMS[:10], "count": 100} + + +@pytest.mark.asyncio +async def test_async_page_number(): + api = NinjaAPI() + + @api.get("/items_page_number", response=List[Any]) + @paginate(PageNumberPagination, page_size=10, pass_parameter="page_info") + async def items_page_number(request, **kwargs): + return ITEMS + [kwargs["page_info"]] + + client = TestAsyncClient(api) + + response = await client.get("/items_page_number?page=11") + assert response.json() == {"items": [{"page": 11}], "count": 101} diff --git a/tests/test_pagination_router.py b/tests/test_pagination_router.py index a6230958c..33b5bd639 100644 --- a/tests/test_pagination_router.py +++ b/tests/test_pagination_router.py @@ -1,8 +1,10 @@ from typing import List +import pytest + from ninja import NinjaAPI, Schema from ninja.pagination import RouterPaginated -from ninja.testing import TestClient +from ninja.testing import TestAsyncClient, TestClient api = NinjaAPI(default_router=RouterPaginated()) @@ -62,3 +64,15 @@ def test_for_NON_list_reponse(): ] # print(parameters) assert parameters == [] + + +@pytest.mark.asyncio +async def test_async_pagination(): + @api.get("/items_async", response=List[ItemSchema]) + async def items_async(request): + return [{"id": i} for i in range(1, 51)] + + client = TestAsyncClient(api) + + response = await client.get("/items_async?offset=5&limit=1") + assert response.json() == {"items": [{"id": 6}], "count": 50}