Skip to content

Commit

Permalink
More robust default behavior on OrderingFilter (#4156)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchristie committed Jun 1, 2016
1 parent dc09eef commit fe2aede
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions rest_framework/filters.py
Expand Up @@ -222,24 +222,40 @@ def get_default_ordering(self, view):
return (ordering,)
return ordering

def get_default_valid_fields(self, queryset, view):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'):
try:
serializer_class = view.get_serializer_class()
except AssertionError:
# Raised by the default implementation if
# no serializer_class was found
serializer_class = None
else:
serializer_class = getattr(view, 'serializer_class', None)

if serializer_class is None:
msg = (
"Cannot use %s on a view which does not have either a "
"'serializer_class', an overriding 'get_serializer_class' "
"or 'ordering_fields' attribute."
)
raise ImproperlyConfigured(msg % self.__class__.__name__)

return [
(field.source or field_name, field.label)
for field_name, field in serializer_class().fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*'
]

def get_valid_fields(self, queryset, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

if valid_fields is None:
# Default to allowing filtering on serializer fields
try:
serializer_class = view.get_serializer_class()
except AssertionError: # raised if no serializer_class was found
msg = ("Cannot use %s on a view which does not have either a "
"'serializer_class', an overriding 'get_serializer_class' "
"or 'ordering_fields' attribute.")
raise ImproperlyConfigured(msg % self.__class__.__name__)
return self.get_default_valid_fields(queryset, view)

valid_fields = [
(field.source or field_name, field.label)
for field_name, field in serializer_class().fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*'
]
elif valid_fields == '__all__':
# View explicitly allows filtering on any model field
valid_fields = [
Expand Down

0 comments on commit fe2aede

Please sign in to comment.