Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug in M2M Filter #41

Closed
wants to merge 12 commits into from
29 changes: 24 additions & 5 deletions django_filters/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,39 @@

from django_filters.widgets import RangeWidget, LookupTypeWidget

class RangeField(forms.MultiValueField):
class BaseRangeField(forms.MultiValueField):
"""
Base abstract class for range filters. Inheriting classes must
define form_field attribure which is a form field class to be used for
validation.
"""
widget = RangeWidget

form_class = None
def __init__(self, *args, **kwargs):
fields = (
forms.DecimalField(),
forms.DecimalField(),
self.form_class(),
self.form_class(),
)
super(RangeField, self).__init__(fields, *args, **kwargs)
super(BaseRangeField, self).__init__(fields, *args, **kwargs)

def compress(self, data_list):
if data_list:
return slice(*data_list)
return None

class NumericRangeField(BaseRangeField):
form_class = forms.DecimalField

class DateRangeField(BaseRangeField):
form_class = forms.DateField

class TimeRangeField(BaseRangeField):
form_class = forms.TimeField

class RangeField(forms.MultiValueField):
"""Deprecated. Use NumericRangeField instead."""
form_class = forms.DecimalField

class LookupTypeField(forms.MultiValueField):
def __init__(self, field, lookup_choices, *args, **kwargs):
fields = (
Expand All @@ -32,3 +50,4 @@ def __init__(self, field, lookup_choices, *args, **kwargs):

def compress(self, data_list):
return data_list

31 changes: 28 additions & 3 deletions django_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from django.db.models.sql.constants import QUERY_TERMS
from django.utils.translation import ugettext_lazy as _

from django_filters.fields import RangeField, LookupTypeField
from django_filters.fields import NumericRangeField, DateRangeField, TimeRangeField, LookupTypeField

__all__ = [
'Filter', 'CharFilter', 'BooleanFilter', 'ChoiceFilter',
'MultipleChoiceFilter', 'DateFilter', 'DateTimeFilter', 'TimeFilter',
'ModelChoiceFilter', 'ModelMultipleChoiceFilter', 'NumberFilter',
'RangeFilter', 'DateRangeFilter', 'AllValuesFilter',
'OpenRangeNumericFilter', 'OpenRangeDateFilter', 'OpenRangeTimeFilter'
]

LOOKUP_TYPES = sorted(QUERY_TERMS.keys())
Expand Down Expand Up @@ -88,8 +89,9 @@ def filter(self, qs, value):
value = value or ()
# TODO: this is a bit of a hack, but ModelChoiceIterator doesn't have a
# __len__ method
if len(value) == len(list(self.field.choices)):
if len(value) == len(list(self.field.choices)) and len(list(self.field.choices)) > 1:
return qs

q = Q()
for v in value:
q |= Q(**{self.name: v})
Expand All @@ -114,7 +116,7 @@ class NumberFilter(Filter):
field_class = forms.DecimalField

class RangeFilter(Filter):
field_class = RangeField
field_class = NumericRangeField

def filter(self, qs, value):
if value:
Expand Down Expand Up @@ -159,3 +161,26 @@ def field(self):
qs = self.model._default_manager.distinct().order_by(self.name).values_list(self.name, flat=True)
self.extra['choices'] = [(o, o) for o in qs]
return super(AllValuesFilter, self).field

class BaseOpenRangeFilter(Filter):
"""
Abstract class similar to RangeFilter but allows open ended ranges.
Inheriting classes must define field_class attribute.
"""

def filter(self, qs, value):
if value:
if value.start:
qs = qs.filter(**{'%s__gte' % self.name: value.start})
if value.stop:
qs = qs.filter(**{'%s__lte' % self.name: value.stop})
return qs

class OpenRangeNumericFilter(BaseOpenRangeFilter):
field_class = NumericRangeField

class OpenRangeDateFilter(BaseOpenRangeFilter):
field_class = DateRangeField

class OpenRangeTimeFilter(BaseOpenRangeFilter):
field_class = TimeRangeField
15 changes: 11 additions & 4 deletions django_filters/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,17 @@ def __new__(cls, name, bases, attrs):
models.URLField: {
'filter_class': CharFilter,
},
models.XMLField: {
'filter_class': CharFilter,
},
models.IPAddressField: {
'filter_class': CharFilter,
},
models.CommaSeparatedIntegerField: {
'filter_class': CharFilter,
},
}
if hasattr(models, "XMLField"):
FILTER_FOR_DBFIELD_DEFAULTS[models.XMLField] = {
'filter_class': CharFilter,
}

class BaseFilterSet(object):
filter_overrides = {}
Expand Down Expand Up @@ -280,7 +281,13 @@ def filter_for_field(cls, f, name):

data = filter_for_field.get(f.__class__)
if data is None:
return
# probably a derived field, inspect parents
for _class in f.__class__.__bases__:
data = filter_for_field.get(_class)
if data:
break
if data is None:
return
filter_class = data.get('filter_class')
default.update(data.get('extra', lambda f: {})(f))
if filter_class is not None:
Expand Down
23 changes: 22 additions & 1 deletion django_filters/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from django.shortcuts import render_to_response
from django.template import RequestContext
from django.http import Http404

from django.core.paginator import EmptyPage

from django_filters.filterset import FilterSet

def object_filter(request, model=None, queryset=None, template_name=None, extra_context=None,
context_processors=None, filter_class=None):
context_processors=None, filter_class=None, page_length=None, page_variable="p"):
if model is None and filter_class is None:
raise TypeError("object_filter must be called with either model or filter_class")
if model is None:
Expand All @@ -25,4 +28,22 @@ def object_filter(request, model=None, queryset=None, template_name=None, extra_
if callable(v):
v = v()
c[k] = v

if page_length:
from django.core.paginator import Paginator
p = Paginator(filterset.qs,page_length)
getvars = request.GET.copy()
if page_variable in getvars:
del getvars[page_variable]

if len(getvars.keys()) > 0:
p.querystring = "&%s" % getvars.urlencode()

try:
c['paginated_filter'] = p.page(request.GET.get(page_variable,1))
except EmptyPage:
raise Http404

c['paginator'] = p

return render_to_response(template_name, c)