Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure CursorPagination respects nulls in the ordering field #8912

Merged
merged 9 commits into from
Apr 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 11 additions & 5 deletions rest_framework/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator
from django.db.models import Q
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
Expand Down Expand Up @@ -620,7 +621,7 @@ def paginate_queryset(self, queryset, request, view=None):
queryset = queryset.order_by(*self.ordering)

# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
if str(current_position) != 'None':
order = self.ordering[0]
is_reversed = order.startswith('-')
order_attr = order.lstrip('-')
Expand All @@ -631,7 +632,12 @@ def paginate_queryset(self, queryset, request, view=None):
else:
kwargs = {order_attr + '__gt': current_position}

queryset = queryset.filter(**kwargs)
filter_query = Q(**kwargs)
# If some records contain a null for the ordering field, don't lose them.
# When reverse ordering, nulls will come last and need to be included.
if (reverse and not is_reversed) or is_reversed:
filter_query |= Q(**{order_attr + '__isnull': True})
queryset = queryset.filter(filter_query)

# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
Expand Down Expand Up @@ -704,7 +710,7 @@ def get_next_link(self):
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
has_item_with_unique_position = position is not None
break

# The item in this position has the same position as the item
Expand Down Expand Up @@ -757,7 +763,7 @@ def get_previous_link(self):
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
has_item_with_unique_position = position is not None
break

# The item in this position has the same position as the item
Expand Down Expand Up @@ -883,7 +889,7 @@ def _get_position_from_instance(self, instance, ordering):
attr = instance[field_name]
else:
attr = getattr(instance, field_name)
return str(attr)
return None if attr is None else str(attr)

def get_paginated_response(self, data):
return Response(OrderedDict([
Expand Down
134 changes: 131 additions & 3 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,17 +951,24 @@ class MockQuerySet:
def __init__(self, items):
self.items = items

def filter(self, created__gt=None, created__lt=None):
def filter(self, q):
q_args = dict(q.deconstruct()[1])
if not q_args:
# django 3.0.x artifact
q_args = dict(q.deconstruct()[2])
created__gt = q_args.get('created__gt')
created__lt = q_args.get('created__lt')

if created__gt is not None:
return MockQuerySet([
item for item in self.items
if item.created > int(created__gt)
if item.created is None or item.created > int(created__gt)
])

assert created__lt is not None
return MockQuerySet([
item for item in self.items
if item.created < int(created__lt)
if item.created is None or item.created < int(created__lt)
])

def order_by(self, *ordering):
Expand Down Expand Up @@ -1080,6 +1087,127 @@ def get_pages(self, url):
return (previous, current, next, previous_url, next_url)


class NullableCursorPaginationModel(models.Model):
created = models.IntegerField(null=True)


class TestCursorPaginationWithNulls(TestCase):
"""
Unit tests for `pagination.CursorPagination` with ordering on a nullable field.
"""

def setUp(self):
class ExamplePagination(pagination.CursorPagination):
page_size = 1
ordering = 'created'

self.pagination = ExamplePagination()
data = [
None, None, 3, 4
]
for idx in data:
NullableCursorPaginationModel.objects.create(created=idx)

self.queryset = NullableCursorPaginationModel.objects.all()

get_pages = TestCursorPagination.get_pages

def test_ascending(self):
"""Test paginating one row at a time, current should go 1, 2, 3, 4, 3, 2, 1."""
(previous, current, next, previous_url, next_url) = self.get_pages('/')

assert previous is None
assert current == [None]
assert next == [None]

(previous, current, next, previous_url, next_url) = self.get_pages(next_url)

assert previous == [None]
assert current == [None]
assert next == [3]

(previous, current, next, previous_url, next_url) = self.get_pages(next_url)

assert previous == [3] # [None] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L789
assert current == [3]
assert next == [4]

(previous, current, next, previous_url, next_url) = self.get_pages(next_url)

assert previous == [3]
assert current == [4]
assert next is None
assert next_url is None

(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

assert previous == [None]
assert current == [3]
assert next == [4]

(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

assert previous == [None]
assert current == [None]
assert next == [None] # [3] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731

(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

assert previous is None
assert current == [None]
assert next == [None]

def test_descending(self):
"""Test paginating one row at a time, current should go 4, 3, 2, 1, 2, 3, 4."""
self.pagination.ordering = ('-created',)
(previous, current, next, previous_url, next_url) = self.get_pages('/')

assert previous is None
assert current == [4]
assert next == [3]

(previous, current, next, previous_url, next_url) = self.get_pages(next_url)

assert previous == [None] # [4] paging artifact
assert current == [3]
assert next == [None]

(previous, current, next, previous_url, next_url) = self.get_pages(next_url)

assert previous == [None] # [3] paging artifact
assert current == [None]
assert next == [None]

(previous, current, next, previous_url, next_url) = self.get_pages(next_url)

assert previous == [None]
assert current == [None]
assert next is None
assert next_url is None

(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

assert previous == [3]
assert current == [None]
assert next == [None]

(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

assert previous == [None]
assert current == [3]
assert next == [3] # [4] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731

# skip back artifact
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)

assert previous is None
assert current == [4]
assert next == [3]


def test_get_displayed_page_numbers():
"""
Test our contextual page display function.
Expand Down