Skip to content

Commit

Permalink
Validate in and range filter inputs (#1090)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
  • Loading branch information
tcleonard and Thomas Leonard committed Jan 10, 2021
1 parent ea84827 commit 10e48c2
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 33 deletions.
2 changes: 1 addition & 1 deletion graphene_django/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
else:
from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter

__all__ = [
"DjangoFilterConnectionField",
Expand Down
75 changes: 75 additions & 0 deletions graphene_django/filter/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from django.core.exceptions import ValidationError
from django.forms import Field

from django_filters import Filter, MultipleChoiceFilter

from graphql_relay.node.node import from_global_id

from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField


class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""

field_class = GlobalIDFormField

def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)


class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField

def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)


class InFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
"""

def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super().filter(qs, value)


def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)


class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]


class RangeFilter(Filter):
field_class = RangeField
23 changes: 1 addition & 22 deletions graphene_django/filter/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,7 @@
from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS

from graphql_relay.node.node import from_global_id

from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField


class GlobalIDFilter(Filter):
field_class = GlobalIDFormField

def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)


class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField

def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter


GRAPHENE_FILTER_SET_OVERRIDES = {
Expand Down
12 changes: 4 additions & 8 deletions graphene_django/filter/tests/test_in_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,19 @@ def test_int_in_filter():
]


def test_int_range_filter():
def test_in_filter_with_empty_list():
"""
Test in filter on an integer field.
Check that using a in filter with an empty list provided as input returns no objects.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)

schema = Schema(query=Query)

query = """
query {
pets (age_Range: [4, 9]) {
pets (name_In: []) {
edges {
node {
name
Expand All @@ -181,7 +180,4 @@ def test_int_range_filter():
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Picotin"}},
]
assert len(result.data["pets"]["edges"]) == 0
114 changes: 114 additions & 0 deletions graphene_django/filter/tests/test_range_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
import pytest

from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet
from graphene_django.utils import DJANGO_FILTER_INSTALLED

pytestmark = []

if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)


class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}


class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)


def test_int_range_filter():
"""
Test range filter on an integer field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)

schema = Schema(query=Query)

query = """
query {
pets (age_Range: [4, 9]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Picotin"}},
]


def test_range_filter_with_invalid_input():
"""
Test range filter used with invalid inputs raise an error.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)

schema = Schema(query=Query)

query = """
query ($rangeValue: [Int]) {
pets (age_Range: $rangeValue) {
edges {
node {
name
}
}
}
}
"""
expected_error = json.dumps(
{
"age__range": [
{
"message": "Invalid range specified: it needs to contain 2 values.",
"code": "invalid",
}
]
}
)

# Empty list
result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"

# Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"

# More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
16 changes: 14 additions & 2 deletions graphene_django/filter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django_filters.filters import Filter, BaseCSVFilter

from .filterset import custom_filterset_factory, setup_filterset
from .filters import InFilter, RangeFilter


def get_filtering_args_from_filterset(filterset_class, type):
Expand Down Expand Up @@ -78,9 +79,20 @@ def replace_csv_filters(filterset_class):
"""
for name, filter_field in list(filterset_class.base_filters.items()):
filter_type = filter_field.lookup_expr
if filter_type in ["in", "range"]:
if filter_type == "in":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = InFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
method=filter_field.method,
exclude=filter_field.exclude,
**filter_field.extra
)

if filter_type == "range":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = Filter(
filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
Expand Down

0 comments on commit 10e48c2

Please sign in to comment.