Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/docs/guides/response/pagination.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def paginate_queryset(self, queryset, pagination: Input, **params):
request = params["request"]
```

#### Async Pagination

Standard **Django Ninja** pagination classes support async. If you wish to handle async requests with a custom pagination class, you should subclass `ninja.pagination.AsyncPaginationBase` and override the `apaginate_queryset(self, queryset, request, **params)` method.

### Output attribute

By defult page items are placed to `'items'` attribute. To override this behaviour use `items_attribute`:
Expand Down
123 changes: 100 additions & 23 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from abc import ABC, abstractmethod
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Tuple, Type
from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple, Type, Union

from django.db.models import QuerySet
from django.http import HttpRequest
Expand All @@ -14,7 +14,11 @@
from ninja.errors import ConfigError
from ninja.operation import Operation
from ninja.signature.details import is_collection_type
from ninja.utils import contribute_operation_args, contribute_operation_callback
from ninja.utils import (
contribute_operation_args,
contribute_operation_callback,
is_async_callable,
)


class PaginationBase(ABC):
Expand Down Expand Up @@ -53,7 +57,24 @@ def _items_count(self, queryset: QuerySet) -> int:
return len(queryset)


class LimitOffsetPagination(PaginationBase):
class AsyncPaginationBase(PaginationBase):
@abstractmethod
async def apaginate_queryset(
self,
queryset: QuerySet,
pagination: Any,
**params: Any,
) -> Any:
pass # pragma: no cover

async def _aitems_count(self, queryset: QuerySet) -> int:
try:
return await queryset.all().acount()
except AttributeError:
return len(queryset)


class LimitOffsetPagination(AsyncPaginationBase):
class Input(Schema):
limit: int = Field(settings.PAGINATION_PER_PAGE, ge=1)
offset: int = Field(0, ge=0)
Expand All @@ -71,8 +92,21 @@ def paginate_queryset(
"count": self._items_count(queryset),
} # noqa: E203

async def apaginate_queryset(
self,
queryset: QuerySet,
pagination: Input,
**params: Any,
) -> Any:
offset = pagination.offset
limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
return {
"items": queryset[offset : offset + limit],
"count": await self._aitems_count(queryset),
} # noqa: E203

class PageNumberPagination(PaginationBase):

class PageNumberPagination(AsyncPaginationBase):
class Input(Schema):
page: int = Field(1, ge=1)

Expand All @@ -94,6 +128,18 @@ def paginate_queryset(
"count": self._items_count(queryset),
} # noqa: E203

async def apaginate_queryset(
self,
queryset: QuerySet,
pagination: Input,
**params: Any,
) -> Any:
offset = (pagination.page - 1) * self.page_size
return {
"items": queryset[offset : offset + self.page_size],
"count": await self._aitems_count(queryset),
} # noqa: E203


def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
"""
Expand All @@ -112,7 +158,9 @@ def my_view(request):
isfunction = inspect.isfunction(func_or_pgn_class)
isnotset = func_or_pgn_class == NOT_SET

pagination_class: Type[PaginationBase] = import_string(settings.PAGINATION_CLASS)
pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string(
settings.PAGINATION_CLASS
)

if isfunction:
return _inject_pagination(func_or_pgn_class, pagination_class)
Expand All @@ -128,26 +176,55 @@ def wrapper(func: Callable) -> Any:

def _inject_pagination(
func: Callable,
paginator_class: Type[PaginationBase],
paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
**paginator_params: Any,
) -> Callable:
paginator: PaginationBase = paginator_class(**paginator_params)

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output: # type: ignore
result[paginator.items_attribute] = list(result[paginator.items_attribute])
# ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
return result
paginator = paginator_class(**paginator_params)
if is_async_callable(func):
if not hasattr(paginator, "apaginate_queryset"):
raise ConfigError("Pagination class not configured for async requests")

@wraps(func)
async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = await func(request, **kwargs)

result = await paginator.apaginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)

async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator:
for result in results:
yield result

if paginator.Output: # type: ignore
result[paginator.items_attribute] = [
result
async for result in evaluate(result[paginator.items_attribute])
]
return result
else:

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output: # type: ignore
result[paginator.items_attribute] = list(
result[paginator.items_attribute]
)
# ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
return result

contribute_operation_args(
view_with_pagination,
Expand Down
123 changes: 123 additions & 0 deletions tests/test_pagination_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import asyncio
from typing import Any, List

import pytest
from django.db.models import QuerySet

from ninja import NinjaAPI, Schema
from ninja.errors import ConfigError
from ninja.pagination import (
AsyncPaginationBase,
PageNumberPagination,
PaginationBase,
paginate,
)
from ninja.testing import TestAsyncClient

api = NinjaAPI()

ITEMS = list(range(100))


class NoAsyncPagination(PaginationBase):
# only offset param, defaults to 5 per page
class Input(Schema):
skip: int

class Output(Schema):
items: List[Any]
count: str
skip: int

def paginate_queryset(self, items, pagination: Input, **params):
skip = pagination.skip
return {
"items": items[skip : skip + 5],
"count": "many",
"skip": skip,
}


class AsyncNoOutputPagination(AsyncPaginationBase):
# Outputs items without count attribute
class Input(Schema):
skip: int

Output = None

def paginate_queryset(self, items, pagination: Input, **params):
skip = pagination.skip
return items[skip : skip + 5]

async def apaginate_queryset(self, items, pagination: Input, **params):
await asyncio.sleep(0)
skip = pagination.skip
return items[skip : skip + 5]

def _items_count(self, queryset: QuerySet) -> int:
try:
# forcing to find queryset.count instead of list.count:
return queryset.all().count()
except AttributeError:
asyncio.sleep(0)
return len(queryset)


@pytest.mark.asyncio
async def test_async_config_error():
api = NinjaAPI()

with pytest.raises(
ConfigError, match="Pagination class not configured for async requests"
):

@api.get("/items_async_undefined", response=List[int])
@paginate(NoAsyncPagination)
async def items_async_undefined(request, **kwargs):
return ITEMS


@pytest.mark.asyncio
async def test_async_custom_pagination():
api = NinjaAPI()

@api.get("/items_async", response=List[int])
@paginate(AsyncNoOutputPagination)
async def items_async(request):
return ITEMS

client = TestAsyncClient(api)

response = await client.get("/items_async?skip=10")
assert response.json() == [10, 11, 12, 13, 14]


@pytest.mark.asyncio
async def test_async_default():
api = NinjaAPI()

@api.get("/items_default", response=List[int])
@paginate # WITHOUT brackets (should use default pagination)
async def items_default(request, someparam: int = 0, **kwargs):
asyncio.sleep(0)
return ITEMS

client = TestAsyncClient(api)

response = await client.get("/items_default?limit=10")
assert response.json() == {"items": ITEMS[:10], "count": 100}


@pytest.mark.asyncio
async def test_async_page_number():
api = NinjaAPI()

@api.get("/items_page_number", response=List[Any])
@paginate(PageNumberPagination, page_size=10, pass_parameter="page_info")
async def items_page_number(request, **kwargs):
return ITEMS + [kwargs["page_info"]]

client = TestAsyncClient(api)

response = await client.get("/items_page_number?page=11")
assert response.json() == {"items": [{"page": 11}], "count": 101}
16 changes: 15 additions & 1 deletion tests/test_pagination_router.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List

import pytest

from ninja import NinjaAPI, Schema
from ninja.pagination import RouterPaginated
from ninja.testing import TestClient
from ninja.testing import TestAsyncClient, TestClient

api = NinjaAPI(default_router=RouterPaginated())

Expand Down Expand Up @@ -62,3 +64,15 @@ def test_for_NON_list_reponse():
]
# print(parameters)
assert parameters == []


@pytest.mark.asyncio
async def test_async_pagination():
@api.get("/items_async", response=List[ItemSchema])
async def items_async(request):
return [{"id": i} for i in range(1, 51)]

client = TestAsyncClient(api)

response = await client.get("/items_async?offset=5&limit=1")
assert response.json() == {"items": [{"id": 6}], "count": 50}