Branch: master
Find file Copy path
326 lines (279 sloc) 12.2 KB
"""
Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
from __future__ import unicode_literals
import operator
import warnings
from functools import reduce
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.sql.constants import ORDER_PATTERN
from django.template import loader
from django.utils import six
from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import (
coreapi, coreschema, distinct, is_guardian_installed
from rest_framework.settings import api_settings
class BaseFilterBackend(object):
    """
    A base class from which all filter backend classes should inherit.

    Subclasses must override `filter_queryset()`; `get_schema_fields()`
    may be overridden to describe the query parameters the backend reads.
    """

    def filter_queryset(self, request, queryset, view):
        """
        Return a filtered queryset.

        The base implementation always raises `NotImplementedError`.
        """
        raise NotImplementedError(".filter_queryset() must be overridden.")

    def get_schema_fields(self, view):
        """
        Return the list of schema fields contributed by this backend.

        The base implementation contributes none. Both `coreapi` and
        `coreschema` must be available, otherwise an `AssertionError`
        is raised.
        """
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return []
class SearchFilter(BaseFilterBackend):
    """
    A filter backend that narrows a queryset to the rows matching the
    search terms supplied in the request's query string.
    """

    # The URL query parameter used for the search.
    search_param = api_settings.SEARCH_PARAM
    template = 'rest_framework/filters/search.html'
    # Maps a one-character prefix on a search field name to the ORM
    # lookup used for that field; unprefixed fields use `icontains`.
    lookup_prefixes = {
        '^': 'istartswith',
        '=': 'iexact',
        '@': 'search',
        '$': 'iregex',
    }
    search_title = _('Search')
    search_description = _('A search term.')

    def get_search_fields(self, view, request):
        """
        Search fields are obtained from the view, but the request is always
        passed to this method. Sub-classes can override this method to
        dynamically change the search fields based on request content.
        """
        return getattr(view, 'search_fields', None)

    def get_search_terms(self, request):
        """
        Search terms are set by a ?search=... query parameter,
        and may be comma and/or whitespace delimited.
        """
        params = request.query_params.get(self.search_param, '')
        return params.replace(',', ' ').split()

    def construct_search(self, field_name):
        """
        Return the ORM lookup string for a single search field.

        A leading character found in `lookup_prefixes` is stripped from
        the field name and selects the lookup; otherwise `icontains` is
        used.
        """
        lookup = self.lookup_prefixes.get(field_name[0])
        if lookup:
            field_name = field_name[1:]
        else:
            lookup = 'icontains'
        return LOOKUP_SEP.join([field_name, lookup])

    def must_call_distinct(self, queryset, search_fields):
        """
        Return True if 'distinct()' should be used to query the given lookups.
        """
        for search_field in search_fields:
            opts = queryset.model._meta
            if search_field[0] in self.lookup_prefixes:
                search_field = search_field[1:]
            parts = search_field.split(LOOKUP_SEP)
            for part in parts:
                field = opts.get_field(part)
                if hasattr(field, 'get_path_info'):
                    # This field is a relation, update opts to follow the relation
                    path_info = field.get_path_info()
                    opts = path_info[-1].to_opts
                    if any(path.m2m for path in path_info):
                        # This field is a m2m relation so we know we need to call distinct
                        return True
        return False

    def filter_queryset(self, request, queryset, view):
        """
        Return `queryset` restricted to rows matching every search term.

        Each term must match at least one of the search fields (terms are
        AND-ed together, fields are OR-ed within a term). The queryset is
        returned untouched when there are no search fields or no terms.
        """
        search_fields = self.get_search_fields(view, request)
        search_terms = self.get_search_terms(request)

        if not search_fields or not search_terms:
            return queryset

        # NOTE(review): the element expression of this comprehension was
        # missing from the file and has been reconstructed from
        # `construct_search()` -- confirm against upstream.
        orm_lookups = [
            self.construct_search(six.text_type(search_field))
            for search_field in search_fields
        ]

        base = queryset
        conditions = []
        for search_term in search_terms:
            queries = [
                models.Q(**{orm_lookup: search_term})
                for orm_lookup in orm_lookups
            ]
            conditions.append(reduce(operator.or_, queries))
        queryset = queryset.filter(reduce(operator.and_, conditions))

        if self.must_call_distinct(queryset, search_fields):
            # Filtering against a many-to-many field requires us to
            # call queryset.distinct() in order to avoid duplicate items
            # in the resulting queryset.
            # We try to avoid this if possible, for performance reasons.
            queryset = distinct(queryset, base)
        return queryset

    def to_html(self, request, queryset, view):
        """
        Render the search template, or return '' when the view declares
        no `search_fields`. Only the first search term is passed to the
        template.
        """
        if not getattr(view, 'search_fields', None):
            return ''

        term = self.get_search_terms(request)
        term = term[0] if term else ''
        context = {
            'param': self.search_param,
            'term': term
        }
        template = loader.get_template(self.template)
        return template.render(context)

    def get_schema_fields(self, view):
        """
        Return the schema field describing the search query parameter.

        Both `coreapi` and `coreschema` must be available, otherwise an
        `AssertionError` is raised.
        """
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        # NOTE(review): the contents of this list were missing from the
        # file; the field below is reconstructed from `search_param`,
        # `search_title` and `search_description` -- confirm against upstream.
        return [
            coreapi.Field(
                name=self.search_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_text(self.search_title),
                    description=force_text(self.search_description)
                )
            )
        ]
class OrderingFilter(BaseFilterBackend):
    """
    A filter backend that orders a queryset according to the ordering
    query parameter of the request, restricted to a set of valid fields.
    """

    # The URL query parameter used for the ordering.
    ordering_param = api_settings.ORDERING_PARAM
    ordering_fields = None
    ordering_title = _('Ordering')
    ordering_description = _('Which field to use when ordering the results.')
    template = 'rest_framework/filters/ordering.html'

    def get_ordering(self, request, queryset, view):
        """
        Ordering is set by a comma delimited ?ordering=... query parameter.

        The `ordering` query parameter can be overridden by setting
        the `ordering_param` value on the OrderingFilter or by
        specifying an `ORDERING_PARAM` value in the API settings.
        """
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            ordering = self.remove_invalid_fields(queryset, fields, view, request)
            if ordering:
                return ordering

        # No ordering was included, or all the ordering fields were invalid
        return self.get_default_ordering(view)

    def get_default_ordering(self, view):
        """
        Return the view's `ordering` attribute, wrapping a bare string in
        a one-element tuple. Returns None when the view declares none.
        """
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, six.string_types):
            return (ordering,)
        return ordering

    def get_default_valid_fields(self, queryset, view, context=None):
        """
        Return `(field key, label)` pairs derived from the view's serializer.

        `context` is passed to the serializer constructor; it defaults to
        an empty dict. Raises `ImproperlyConfigured` when no serializer
        class can be determined for the view.
        """
        # A fresh dict per call avoids sharing one mutable default.
        if context is None:
            context = {}

        # If `ordering_fields` is not specified, then we determine a default
        # based on the serializer class, if one exists on the view.
        if hasattr(view, 'get_serializer_class'):
            try:
                serializer_class = view.get_serializer_class()
            except AssertionError:
                # Raised by the default implementation if
                # no serializer_class was found
                serializer_class = None
        else:
            serializer_class = getattr(view, 'serializer_class', None)

        if serializer_class is None:
            msg = (
                "Cannot use %s on a view which does not have either a "
                "'serializer_class', an overriding 'get_serializer_class' "
                "or 'ordering_fields' attribute."
            )
            raise ImproperlyConfigured(msg % self.__class__.__name__)

        # Write-only fields and whole-object ('*') sources are skipped.
        return [
            (field.source.replace('.', '__') or field_name, field.label)
            for field_name, field in serializer_class(context=context).fields.items()
            if not getattr(field, 'write_only', False) and not field.source == '*'
        ]

    def get_valid_fields(self, queryset, view, context=None):
        """
        Return the `(field key, label)` pairs that may be used for ordering.

        The view's `ordering_fields` attribute (falling back to the class
        attribute) decides the source: None uses the serializer fields,
        '__all__' uses the model fields plus query annotations, and any
        other value is treated as an explicit list.
        """
        if context is None:
            context = {}

        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

        if valid_fields is None:
            # Default to allowing filtering on serializer fields
            return self.get_default_valid_fields(queryset, view, context)

        elif valid_fields == '__all__':
            # View explicitly allows filtering on any model field
            # NOTE(review): the first tuple element was truncated in the
            # file; `field.name` is reconstructed from how the pairs are
            # consumed (key, label) -- confirm against upstream.
            valid_fields = [
                (field.name, field.verbose_name) for field in queryset.model._meta.fields
            ]
            valid_fields += [
                (key, key.title().split('__'))
                for key in queryset.query.annotations
            ]
        else:
            # Bare strings act as both key and label.
            valid_fields = [
                (item, item) if isinstance(item, six.string_types) else item
                for item in valid_fields
            ]

        return valid_fields

    def remove_invalid_fields(self, queryset, fields, view, request):
        """
        Return the entries of `fields` that name a valid field (ignoring
        a leading '-') and match `ORDER_PATTERN`.
        """
        valid_fields = [
            item[0]
            for item in self.get_valid_fields(queryset, view, {'request': request})
        ]
        return [
            term for term in fields
            if term.lstrip('-') in valid_fields and ORDER_PATTERN.match(term)
        ]

    def filter_queryset(self, request, queryset, view):
        """
        Return `queryset` ordered by the resolved ordering, or unchanged
        when no ordering applies.
        """
        ordering = self.get_ordering(request, queryset, view)

        if ordering:
            return queryset.order_by(*ordering)

        return queryset

    def get_template_context(self, request, queryset, view):
        """
        Build the template context: the request, the current (first)
        ordering term, the parameter name, and ascending/descending
        options for every valid field.
        """
        current = self.get_ordering(request, queryset, view)
        current = None if not current else current[0]
        options = []
        context = {
            'request': request,
            'current': current,
            'param': self.ordering_param,
        }
        for key, label in self.get_valid_fields(queryset, view, context):
            options.append((key, '%s - %s' % (label, _('ascending'))))
            options.append(('-' + key, '%s - %s' % (label, _('descending'))))
        context['options'] = options
        return context

    def to_html(self, request, queryset, view):
        """
        Render the ordering template with the context built by
        `get_template_context()`.
        """
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)

    def get_schema_fields(self, view):
        """
        Return the schema field describing the ordering query parameter.

        Both `coreapi` and `coreschema` must be available, otherwise an
        `AssertionError` is raised.
        """
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        # NOTE(review): the contents of this list were missing from the
        # file; the field below is reconstructed from `ordering_param`,
        # `ordering_title` and `ordering_description` -- confirm against upstream.
        return [
            coreapi.Field(
                name=self.ordering_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_text(self.ordering_title),
                    description=force_text(self.ordering_description)
                )
            )
        ]
class DjangoObjectPermissionsFilter(BaseFilterBackend):
    """
    A filter backend that limits results to those where the requesting user
    has read object level permissions.
    """

    # Permission codename template, filled in with the queryset model's
    # `app_label` and `model_name`.
    perm_format = '%(app_label)s.view_%(model_name)s'

    def __init__(self):
        """
        Emit a `DeprecationWarning` and check that django-guardian is
        installed (raises `AssertionError` otherwise).
        """
        warnings.warn(
            "`DjangoObjectPermissionsFilter` has been deprecated and moved to "
            "the 3rd-party django-rest-framework-guardian package.",
            DeprecationWarning, stacklevel=2
        )
        assert is_guardian_installed(), 'Using DjangoObjectPermissionsFilter, but django-guardian is not installed'

    def filter_queryset(self, request, queryset, view):
        """
        Return the objects of `queryset` for which `request.user` holds
        the permission built from `perm_format`.
        """
        # We want to defer this import until run-time, rather than import-time.
        # (Also see #1624 for why we need to make this import explicitly)
        from guardian import VERSION as guardian_version
        from guardian.shortcuts import get_objects_for_user

        user = request.user
        model_cls = queryset.model
        kwargs = {
            'app_label': model_cls._meta.app_label,
            'model_name': model_cls._meta.model_name
        }
        permission = self.perm_format % kwargs
        if tuple(guardian_version) >= (1, 3):
            # Maintain behavior compatibility with versions prior to 1.3
            extra = {'accept_global_perms': False}
        else:
            extra = {}
        return get_objects_for_user(user, permission, queryset, **extra)