Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use get_serializer_class in ordering filter #3487

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions rest_framework/filters.py
Expand Up @@ -227,11 +227,14 @@ def get_valid_fields(self, queryset, view):

if valid_fields is None:
# Default to allowing filtering on serializer fields
serializer_class = getattr(view, 'serializer_class')
if serializer_class is None:
try:
serializer_class = view.get_serializer_class()
except AssertionError: # raised if no serializer_class was found
msg = ("Cannot use %s on a view which does not have either a "
"'serializer_class' or 'ordering_fields' attribute.")
"'serializer_class', an overriding 'get_serializer_class' "
"or 'ordering_fields' attribute.")
raise ImproperlyConfigured(msg % self.__class__.__name__)

valid_fields = [
(field.source or field_name, field.label)
for field_name, field in serializer_class().fields.items()
Expand Down
36 changes: 36 additions & 0 deletions tests/test_filters.py
Expand Up @@ -5,6 +5,7 @@
from decimal import Decimal

from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import reverse
from django.db import models
from django.test import TestCase
Expand Down Expand Up @@ -754,6 +755,41 @@ class OrderingListView(generics.ListAPIView):

self.assertContains(response, 'verbose title')

def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
# note: no ordering_fields and serializer_class speficied

def get_serializer_class(self):
return OrderingFilterSerializer

view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
response = view(request)
self.assertEqual(
response.data,
[
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
)

def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
# note: no ordering_fields and serializer_class
# or get_serializer_class speficied

view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
with self.assertRaises(ImproperlyConfigured):
view(request)


class SensitiveOrderingFilterModel(models.Model):
username = models.CharField(max_length=20)
Expand Down