Skip to content

Commit

Permalink
Add enum support to filters and fix filter typing (v3) (#1119)
Browse files Browse the repository at this point in the history
* - Add filtering support for choice fields converted to graphql Enum (or not)
- Fix type of various filters (used to default to String)
- Fix bug with contains introduced in previous PR
- Fix bug with declared filters being overridden (see PR #1108)
- Fix support for ArrayField and add documentation

* Fix for v3

Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
  • Loading branch information
tcleonard and Thomas Leonard authored Feb 23, 2021
1 parent 5ce4553 commit 2d4ca0a
Show file tree
Hide file tree
Showing 18 changed files with 912 additions and 122 deletions.
43 changes: 43 additions & 0 deletions docs/filtering.rst
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,46 @@ with this set up, you can now order the users under group:
}
}
}
PostgreSQL `ArrayField`
-----------------------

Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters:

.. code:: python
from django.db import models
from django_filters import FilterSet, OrderingFilter
from graphene_django.filter import ArrayFilter
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
filterset_class = EventFilterSet
with this set up, you can now filter events by tags:

.. code::
query {
events(tags_Overlap: ["concert", "festival"]) {
name
}
}
2 changes: 1 addition & 1 deletion graphene_django/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


class BlankValueField(Field):
def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver

# create custom resolver
Expand Down
11 changes: 10 additions & 1 deletion graphene_django/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@
)
else:
from .fields import DjangoFilterConnectionField
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .filters import (
ArrayFilter,
GlobalIDFilter,
GlobalIDMultipleChoiceFilter,
ListFilter,
RangeFilter,
)

__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
]
27 changes: 23 additions & 4 deletions graphene_django/filter/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@
from functools import partial

from django.core.exceptions import ValidationError

from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case

from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class


def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data


class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
self,
Expand Down Expand Up @@ -43,8 +62,8 @@ def filterset_class(self):
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)

filterset_class = self._provided_filterset_class or (
self.node_type._meta.filterset_class
filterset_class = (
self._provided_filterset_class or self.node_type._meta.filterset_class
)
self._filterset_class = get_filterset_class(filterset_class, **meta)

Expand All @@ -68,7 +87,7 @@ def filter_kwargs():
if k in filtering_args:
if k == "order_by" and v is not None:
v = to_snake_case(v)
kwargs[k] = v
kwargs[k] = convert_enum(v)
return kwargs

qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
Expand All @@ -78,7 +97,7 @@ def filter_kwargs():
filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context
)
if filterset.form.is_valid():
if filterset.is_valid():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())

Expand Down
32 changes: 29 additions & 3 deletions graphene_django/filter/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.forms import Field

from django_filters import Filter, MultipleChoiceFilter
from django_filters.constants import EMPTY_VALUES

from graphql_relay.node.node import from_global_id

Expand Down Expand Up @@ -31,14 +32,15 @@ def filter(self, qs, value):
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)


class InFilter(Filter):
class ListFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""

def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
Expand Down Expand Up @@ -73,3 +75,27 @@ class RangeField(Field):

class RangeFilter(Filter):
field_class = RangeField


class ArrayFilter(Filter):
"""
Filter made for PostgreSQL ArrayField.
"""

def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get the filter applied with
an empty list since it's a valid value but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value in EMPTY_VALUES and value != []:
return qs
if self.distinct:
qs = qs.distinct()
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
qs = self.get_method(qs)(**{lookup: value})
return qs
59 changes: 41 additions & 18 deletions graphene_django/filter/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.filter import ArrayFilter, ListFilter

from ...compat import ArrayField

Expand All @@ -27,49 +28,61 @@
STORE = {"events": []}


@pytest.fixture
def Event():
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))

return Event
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())


@pytest.fixture
def EventFilterSet(Event):

from django.contrib.postgres.forms import SimpleArrayField

class ArrayFilter(filters.Filter):
base_field_class = SimpleArrayField

def EventFilterSet():
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact"],
"name": ["exact", "contains"],
}

# Those are actually usable with our Query fixture bellow
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")

# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
random_field__contains = ArrayFilter(
field_name="random_field", lookup_expr="contains"
)
random_field__overlap = ArrayFilter(
field_name="random_field", lookup_expr="overlap"
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")

return EventFilterSet


@pytest.fixture
def EventType(Event, EventFilterSet):
def EventType(EventFilterSet):
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet

return EventType


@pytest.fixture
def Query(Event, EventType):
def Query(EventType):
"""
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields.
"""

class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)

Expand All @@ -79,6 +92,7 @@ def resolve_events(self, info, **kwargs):
Event(name="Live Show", tags=["concert", "music", "rock"],),
Event(name="Musical", tags=["movie", "music"],),
Event(name="Ballet", tags=["concert", "dance"],),
Event(name="Speech", tags=[],),
]

STORE["events"] = events
Expand All @@ -105,6 +119,13 @@ def filter_events(**kwargs):
STORE["events"],
)
)
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
)
)

def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
Expand All @@ -121,7 +142,9 @@ def mock_queryset_count(*args, **kwargs):
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__
m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)

return m_queryset

Expand Down
2 changes: 1 addition & 1 deletion graphene_django/filter/tests/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Meta:
fields = {
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"],
"reporter": ["exact", "in"],
}

order_by = OrderingFilter(fields=("pub_date",))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_multiple(Query):
def test_array_field_contains_multiple(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)
Expand All @@ -32,9 +32,9 @@ def test_string_contains_multiple(Query):


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_one(Query):
def test_array_field_contains_one(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)
Expand All @@ -59,9 +59,9 @@ def test_string_contains_one(Query):


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_none(Query):
def test_array_field_contains_empty_list(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)
Expand All @@ -79,4 +79,9 @@ def test_string_contains_none(Query):
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
{"node": {"name": "Speech"}},
]
Loading

0 comments on commit 2d4ca0a

Please sign in to comment.