-
Notifications
You must be signed in to change notification settings - Fork 766
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Validate in and range filter inputs (#1090)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
- Loading branch information
Showing
6 changed files
with
209 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from django.core.exceptions import ValidationError | ||
from django.forms import Field | ||
|
||
from django_filters import Filter, MultipleChoiceFilter | ||
|
||
from graphql_relay.node.node import from_global_id | ||
|
||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField | ||
|
||
|
||
class GlobalIDFilter(Filter): | ||
""" | ||
Filter for Relay global ID. | ||
""" | ||
|
||
field_class = GlobalIDFormField | ||
|
||
def filter(self, qs, value): | ||
""" Convert the filter value to a primary key before filtering """ | ||
_id = None | ||
if value is not None: | ||
_, _id = from_global_id(value) | ||
return super(GlobalIDFilter, self).filter(qs, _id) | ||
|
||
|
||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): | ||
field_class = GlobalIDMultipleChoiceField | ||
|
||
def filter(self, qs, value): | ||
gids = [from_global_id(v)[1] for v in value] | ||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids) | ||
|
||
|
||
class InFilter(Filter): | ||
""" | ||
Filter for a list of value using the `__in` Django filter. | ||
""" | ||
|
||
def filter(self, qs, value): | ||
""" | ||
Override the default filter class to check first weather the list is | ||
empty or not. | ||
This needs to be done as in this case we expect to get an empty output | ||
(if not an exclude filter) but django_filter consider an empty list | ||
to be an empty input value (see `EMPTY_VALUES`) meaning that | ||
the filter does not need to be applied (hence returning the original | ||
queryset). | ||
""" | ||
if value is not None and len(value) == 0: | ||
if self.exclude: | ||
return qs | ||
else: | ||
return qs.none() | ||
else: | ||
return super().filter(qs, value) | ||
|
||
|
||
def validate_range(value): | ||
""" | ||
Validator for range filter input: the list of value must be of length 2. | ||
Note that validators are only run if the value is not empty. | ||
""" | ||
if len(value) != 2: | ||
raise ValidationError( | ||
"Invalid range specified: it needs to contain 2 values.", code="invalid" | ||
) | ||
|
||
|
||
class RangeField(Field): | ||
default_validators = [validate_range] | ||
empty_values = [None] | ||
|
||
|
||
class RangeFilter(Filter): | ||
field_class = RangeField |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import json | ||
import pytest | ||
|
||
from django_filters import FilterSet | ||
from django_filters import rest_framework as filters | ||
from graphene import ObjectType, Schema | ||
from graphene.relay import Node | ||
from graphene_django import DjangoObjectType | ||
from graphene_django.tests.models import Pet | ||
from graphene_django.utils import DJANGO_FILTER_INSTALLED | ||
|
||
pytestmark = [] | ||
|
||
if DJANGO_FILTER_INSTALLED: | ||
from graphene_django.filter import DjangoFilterConnectionField | ||
else: | ||
pytestmark.append( | ||
pytest.mark.skipif( | ||
True, reason="django_filters not installed or not compatible" | ||
) | ||
) | ||
|
||
|
||
class PetNode(DjangoObjectType): | ||
class Meta: | ||
model = Pet | ||
interfaces = (Node,) | ||
filter_fields = { | ||
"name": ["exact", "in"], | ||
"age": ["exact", "in", "range"], | ||
} | ||
|
||
|
||
class Query(ObjectType): | ||
pets = DjangoFilterConnectionField(PetNode) | ||
|
||
|
||
def test_int_range_filter(): | ||
""" | ||
Test range filter on an integer field. | ||
""" | ||
Pet.objects.create(name="Brutus", age=12) | ||
Pet.objects.create(name="Mimi", age=8) | ||
Pet.objects.create(name="Jojo, the rabbit", age=3) | ||
Pet.objects.create(name="Picotin", age=5) | ||
|
||
schema = Schema(query=Query) | ||
|
||
query = """ | ||
query { | ||
pets (age_Range: [4, 9]) { | ||
edges { | ||
node { | ||
name | ||
} | ||
} | ||
} | ||
} | ||
""" | ||
result = schema.execute(query) | ||
assert not result.errors | ||
assert result.data["pets"]["edges"] == [ | ||
{"node": {"name": "Mimi"}}, | ||
{"node": {"name": "Picotin"}}, | ||
] | ||
|
||
|
||
def test_range_filter_with_invalid_input(): | ||
""" | ||
Test range filter used with invalid inputs raise an error. | ||
""" | ||
Pet.objects.create(name="Brutus", age=12) | ||
Pet.objects.create(name="Mimi", age=8) | ||
Pet.objects.create(name="Jojo, the rabbit", age=3) | ||
Pet.objects.create(name="Picotin", age=5) | ||
|
||
schema = Schema(query=Query) | ||
|
||
query = """ | ||
query ($rangeValue: [Int]) { | ||
pets (age_Range: $rangeValue) { | ||
edges { | ||
node { | ||
name | ||
} | ||
} | ||
} | ||
} | ||
""" | ||
expected_error = json.dumps( | ||
{ | ||
"age__range": [ | ||
{ | ||
"message": "Invalid range specified: it needs to contain 2 values.", | ||
"code": "invalid", | ||
} | ||
] | ||
} | ||
) | ||
|
||
# Empty list | ||
result = schema.execute(query, variables={"rangeValue": []}) | ||
assert len(result.errors) == 1 | ||
assert result.errors[0].message == f"['{expected_error}']" | ||
|
||
# Only one item in the list | ||
result = schema.execute(query, variables={"rangeValue": [1]}) | ||
assert len(result.errors) == 1 | ||
assert result.errors[0].message == f"['{expected_error}']" | ||
|
||
# More than 2 items in the list | ||
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]}) | ||
assert len(result.errors) == 1 | ||
assert result.errors[0].message == f"['{expected_error}']" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters