From 2d4ca0ac7b60413a7c5d9a5a995ad478c05d6352 Mon Sep 17 00:00:00 2001 From: Thomas Leonard <64223923+tcleonard@users.noreply.github.com> Date: Tue, 23 Feb 2021 05:21:32 +0100 Subject: [PATCH] Add enum support to filters and fix filter typing (v3) (#1119) * - Add filtering support for choice fields converted to graphql Enum (or not) - Fix type of various filters (used to default to String) - Fix bug with contains introduced in previous PR - Fix bug with declared filters being overridden (see PR #1108) - Fix support for ArrayField and add documentation * Fix for v3 Co-authored-by: Thomas Leonard --- docs/filtering.rst | 43 +++ graphene_django/converter.py | 2 +- graphene_django/filter/__init__.py | 11 +- graphene_django/filter/fields.py | 27 +- graphene_django/filter/filters.py | 32 +- graphene_django/filter/tests/conftest.py | 59 +++- graphene_django/filter/tests/filters.py | 2 +- ...py => test_array_field_contains_filter.py} | 19 +- .../tests/test_array_field_exact_filter.py | 129 +++++++ ....py => test_array_field_overlap_filter.py} | 12 +- .../filter/tests/test_enum_filtering.py | 160 +++++++++ graphene_django/filter/tests/test_fields.py | 53 ++- .../filter/tests/test_in_filter.py | 333 ++++++++++++++++-- .../filter/tests/test_range_filter.py | 1 + graphene_django/filter/utils.py | 137 ++++--- graphene_django/tests/models.py | 6 +- graphene_django/tests/test_query.py | 7 + graphene_django/tests/test_types.py | 1 + 18 files changed, 912 insertions(+), 122 deletions(-) rename graphene_django/filter/tests/{test_contains_filter.py => test_array_field_contains_filter.py} (74%) create mode 100644 graphene_django/filter/tests/test_array_field_exact_filter.py rename graphene_django/filter/tests/{test_overlap_filter.py => test_array_field_overlap_filter.py} (84%) create mode 100644 graphene_django/filter/tests/test_enum_filtering.py diff --git a/docs/filtering.rst b/docs/filtering.rst index 6a57bf928..beb5e5b43 100644 --- a/docs/filtering.rst +++ b/docs/filtering.rst @@ -258,3 +258,46 @@ with this set up, you can now order the users under group: } } } + + +PostgreSQL `ArrayField` +----------------------- + +Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters: + +.. code:: python + + from django.db import models + from django_filters import FilterSet, OrderingFilter + from graphene_django.filter import ArrayFilter + + class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + + class EventFilterSet(FilterSet): + class Meta: + model = Event + fields = { + "name": ["exact", "contains"], + } + + tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") + tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + tags = ArrayFilter(field_name="tags", lookup_expr="exact") + + class EventType(DjangoObjectType): + class Meta: + model = Event + interfaces = (Node,) + filterset_class = EventFilterSet + +with this set up, you can now filter events by tags: + +.. code:: + + query { + events(tags_Overlap: ["concert", "festival"]) { + name + } + } diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 525bb11e3..6bbf53452 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -35,7 +35,7 @@ class BlankValueField(Field): - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): resolver = self.resolver or parent_resolver # create custom resolver diff --git a/graphene_django/filter/__init__.py b/graphene_django/filter/__init__.py index 5de36adce..94570c98c 100644 --- a/graphene_django/filter/__init__.py +++ b/graphene_django/filter/__init__.py @@ -9,10 +9,19 @@ ) else: from .fields import DjangoFilterConnectionField - from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter + from .filters import ( + ArrayFilter, + GlobalIDFilter, + GlobalIDMultipleChoiceFilter, + ListFilter, + RangeFilter, + ) __all__ = [ "DjangoFilterConnectionField", "GlobalIDFilter", "GlobalIDMultipleChoiceFilter", + "ArrayFilter", + "ListFilter", + "RangeFilter", ] diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 244fb3987..c6dd50ee6 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -2,12 +2,31 @@ from functools import partial from django.core.exceptions import ValidationError + +from graphene.types.enum import EnumType from graphene.types.argument import to_arguments from graphene.utils.str_converters import to_snake_case + from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class +def convert_enum(data): + """ + Check if the data is a enum option (or potentially nested list of enum option) + and convert it to its value. + + This method is used to pre-process the data for the filters as they can take an + graphene.Enum as argument, but filters (from django_filters) expect a simple value. + """ + if isinstance(data, list): + return [convert_enum(item) for item in data] + if isinstance(type(data), EnumType): + return data.value + else: + return data + + class DjangoFilterConnectionField(DjangoConnectionField): def __init__( self, @@ -43,8 +62,8 @@ def filterset_class(self): if self._extra_filter_meta: meta.update(self._extra_filter_meta) - filterset_class = self._provided_filterset_class or ( - self.node_type._meta.filterset_class + filterset_class = ( + self._provided_filterset_class or self.node_type._meta.filterset_class ) self._filterset_class = get_filterset_class(filterset_class, **meta) @@ -68,7 +87,7 @@ def filter_kwargs(): if k in filtering_args: if k == "order_by" and v is not None: v = to_snake_case(v) - kwargs[k] = v + kwargs[k] = convert_enum(v) return kwargs qs = super(DjangoFilterConnectionField, cls).resolve_queryset( @@ -78,7 +97,7 @@ def filter_kwargs(): filterset = filterset_class( data=filter_kwargs(), queryset=qs, request=info.context ) - if filterset.form.is_valid(): + if filterset.is_valid(): return filterset.qs raise ValidationError(filterset.form.errors.as_json()) diff --git a/graphene_django/filter/filters.py b/graphene_django/filter/filters.py index 58d7d0852..e23626a73 100644 --- a/graphene_django/filter/filters.py +++ b/graphene_django/filter/filters.py @@ -2,6 +2,7 @@ from django.forms import Field from django_filters import Filter, MultipleChoiceFilter +from django_filters.constants import EMPTY_VALUES from graphql_relay.node.node import from_global_id @@ -31,14 +32,15 @@ def filter(self, qs, value): return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids) -class InFilter(Filter): +class ListFilter(Filter): """ - Filter for a list of value using the `__in` Django filter. + Filter that takes a list of value as input. + It is for example used for `__in` filters. """ def filter(self, qs, value): """ - Override the default filter class to check first weather the list is + Override the default filter class to check first whether the list is empty or not. This needs to be done as in this case we expect to get an empty output (if not an exclude filter) but django_filter consider an empty list @@ -73,3 +75,27 @@ class RangeField(Field): class RangeFilter(Filter): field_class = RangeField + + +class ArrayFilter(Filter): + """ + Filter made for PostgreSQL ArrayField. + """ + + def filter(self, qs, value): + """ + Override the default filter class to check first whether the list is + empty or not. + This needs to be done as in this case we expect to get the filter applied with + an empty list since it's a valid value but django_filter consider an empty list + to be an empty input value (see `EMPTY_VALUES`) meaning that + the filter does not need to be applied (hence returning the original + queryset). + """ + if value in EMPTY_VALUES and value != []: + return qs + if self.distinct: + qs = qs.distinct() + lookup = "%s__%s" % (self.field_name, self.lookup_expr) + qs = self.get_method(qs)(**{lookup: value}) + return qs diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py index 031364519..57924aff6 100644 --- a/graphene_django/filter/tests/conftest.py +++ b/graphene_django/filter/tests/conftest.py @@ -9,6 +9,7 @@ from graphene.relay import Node from graphene_django import DjangoObjectType from graphene_django.utils import DJANGO_FILTER_INSTALLED +from graphene_django.filter import ArrayFilter, ListFilter from ...compat import ArrayField @@ -27,49 +28,61 @@ STORE = {"events": []} -@pytest.fixture -def Event(): - class Event(models.Model): - name = models.CharField(max_length=50) - tags = ArrayField(models.CharField(max_length=50)) - - return Event +class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + tag_ids = ArrayField(models.IntegerField()) + random_field = ArrayField(models.BooleanField()) @pytest.fixture -def EventFilterSet(Event): - - from django.contrib.postgres.forms import SimpleArrayField - - class ArrayFilter(filters.Filter): - base_field_class = SimpleArrayField - +def EventFilterSet(): class EventFilterSet(FilterSet): class Meta: model = Event fields = { - "name": ["exact"], + "name": ["exact", "contains"], } + # Those are actually usable with our Query fixture bellow tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + tags = ArrayFilter(field_name="tags", lookup_expr="exact") + + # Those are actually not usable and only to check type declarations + tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains") + tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap") + tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact") + random_field__contains = ArrayFilter( + field_name="random_field", lookup_expr="contains" + ) + random_field__overlap = ArrayFilter( + field_name="random_field", lookup_expr="overlap" + ) + random_field = ArrayFilter(field_name="random_field", lookup_expr="exact") return EventFilterSet @pytest.fixture -def EventType(Event, EventFilterSet): +def EventType(EventFilterSet): class EventType(DjangoObjectType): class Meta: model = Event interfaces = (Node,) + fields = "__all__" filterset_class = EventFilterSet return EventType @pytest.fixture -def Query(Event, EventType): +def Query(EventType): + """ + Note that we have to use a custom resolver to replicate the arrayfield filter behavior as + we are running unit tests in sqlite which does not have ArrayFields. + """ + class Query(graphene.ObjectType): events = DjangoFilterConnectionField(EventType) @@ -79,6 +92,7 @@ def resolve_events(self, info, **kwargs): Event(name="Live Show", tags=["concert", "music", "rock"],), Event(name="Musical", tags=["movie", "music"],), Event(name="Ballet", tags=["concert", "dance"],), + Event(name="Speech", tags=[],), ] STORE["events"] = events @@ -105,6 +119,13 @@ def filter_events(**kwargs): STORE["events"], ) ) + if "tags__exact" in kwargs: + STORE["events"] = list( + filter( + lambda e: set(kwargs["tags__exact"]) == set(e.tags), + STORE["events"], + ) + ) def mock_queryset_filter(*args, **kwargs): filter_events(**kwargs) @@ -121,7 +142,9 @@ def mock_queryset_count(*args, **kwargs): m_queryset.filter.side_effect = mock_queryset_filter m_queryset.none.side_effect = mock_queryset_none m_queryset.count.side_effect = mock_queryset_count - m_queryset.__getitem__.side_effect = STORE["events"].__getitem__ + m_queryset.__getitem__.side_effect = lambda index: STORE[ + "events" + ].__getitem__(index) return m_queryset diff --git a/graphene_django/filter/tests/filters.py b/graphene_django/filter/tests/filters.py index 43b6a878d..a7443c07f 100644 --- a/graphene_django/filter/tests/filters.py +++ b/graphene_django/filter/tests/filters.py @@ -10,7 +10,7 @@ class Meta: fields = { "headline": ["exact", "icontains"], "pub_date": ["gt", "lt", "exact"], - "reporter": ["exact"], + "reporter": ["exact", "in"], } order_by = OrderingFilter(fields=("pub_date",)) diff --git a/graphene_django/filter/tests/test_contains_filter.py b/graphene_django/filter/tests/test_array_field_contains_filter.py similarity index 74% rename from graphene_django/filter/tests/test_contains_filter.py rename to graphene_django/filter/tests/test_array_field_contains_filter.py index 35e775ef5..4144614c7 100644 --- a/graphene_django/filter/tests/test_contains_filter.py +++ b/graphene_django/filter/tests/test_array_field_contains_filter.py @@ -6,9 +6,9 @@ @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_multiple(Query): +def test_array_field_contains_multiple(Query): """ - Test contains filter on a string field. + Test contains filter on a array field of string. """ schema = Schema(query=Query) @@ -32,9 +32,9 @@ def test_string_contains_multiple(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_one(Query): +def test_array_field_contains_one(Query): """ - Test contains filter on a string field. + Test contains filter on a array field of string. """ schema = Schema(query=Query) @@ -59,9 +59,9 @@ def test_string_contains_one(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_none(Query): +def test_array_field_contains_empty_list(Query): """ - Test contains filter on a string field. + Test contains filter on a array field of string. """ schema = Schema(query=Query) @@ -79,4 +79,9 @@ def test_string_contains_none(Query): """ result = schema.execute(query) assert not result.errors - assert result.data["events"]["edges"] == [] + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + {"node": {"name": "Ballet"}}, + {"node": {"name": "Speech"}}, + ] diff --git a/graphene_django/filter/tests/test_array_field_exact_filter.py b/graphene_django/filter/tests/test_array_field_exact_filter.py new file mode 100644 index 000000000..b07abede5 --- /dev/null +++ b/graphene_django/filter/tests/test_array_field_exact_filter.py @@ -0,0 +1,129 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_exact_no_match(Query): + """ + Test exact filter on a array field of string. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_exact_match(Query): + """ + Test exact filter on a array field of string. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags: ["movie", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_exact_empty_list(Query): + """ + Test exact filter on a array field of string. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Speech"}}, + ] + + +def test_array_field_filter_schema_type(Query): + """ + Check that the type in the filter is an array field like on the object type. + """ + schema = Schema(query=Query) + schema_str = str(schema) + + assert ( + '''type EventType implements Node { + """The ID of the object""" + id: ID! + name: String! + tags: [String!]! + tagIds: [Int!]! + randomField: [Boolean!]! +}''' + in schema_str + ) + + filters = { + "offset": "Int", + "before": "String", + "after": "String", + "first": "Int", + "last": "Int", + "name": "String", + "name_Contains": "String", + "tags_Contains": "[String!]", + "tags_Overlap": "[String!]", + "tags": "[String!]", + "tagsIds_Contains": "[Int!]", + "tagsIds_Overlap": "[Int!]", + "tagsIds": "[Int!]", + "randomField_Contains": "[Boolean!]", + "randomField_Overlap": "[Boolean!]", + "randomField": "[Boolean!]", + } + filters_str = ", ".join( + [ + f"{filter_field}: {gql_type} = null" + for filter_field, gql_type in filters.items() + ] + ) + assert ( + f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str + ) diff --git a/graphene_django/filter/tests/test_overlap_filter.py b/graphene_django/filter/tests/test_array_field_overlap_filter.py similarity index 84% rename from graphene_django/filter/tests/test_overlap_filter.py rename to graphene_django/filter/tests/test_array_field_overlap_filter.py index 32dfa44a1..5ce1576b3 100644 --- a/graphene_django/filter/tests/test_overlap_filter.py +++ b/graphene_django/filter/tests/test_array_field_overlap_filter.py @@ -6,9 +6,9 @@ @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_multiple(Query): +def test_array_field_overlap_multiple(Query): """ - Test overlap filter on a string field. + Test overlap filter on a array field of string. """ schema = Schema(query=Query) @@ -34,9 +34,9 @@ def test_string_overlap_multiple(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_one(Query): +def test_array_field_overlap_one(Query): """ - Test overlap filter on a string field. + Test overlap filter on a array field of string. """ schema = Schema(query=Query) @@ -61,9 +61,9 @@ def test_string_overlap_one(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_none(Query): +def test_array_field_overlap_empty_list(Query): """ - Test overlap filter on a string field. + Test overlap filter on a array field of string. """ schema = Schema(query=Query) diff --git a/graphene_django/filter/tests/test_enum_filtering.py b/graphene_django/filter/tests/test_enum_filtering.py new file mode 100644 index 000000000..09c69b393 --- /dev/null +++ b/graphene_django/filter/tests/test_enum_filtering.py @@ -0,0 +1,160 @@ +import pytest + +import graphene +from graphene.relay import Node + +from graphene_django import DjangoObjectType, DjangoConnectionField +from graphene_django.tests.models import Article, Reporter +from graphene_django.utils import DJANGO_FILTER_INSTALLED + +pytestmark = [] + +if DJANGO_FILTER_INSTALLED: + from graphene_django.filter import DjangoFilterConnectionField +else: + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) + + +@pytest.fixture +def schema(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + fields = "__all__" + + class ArticleType(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + fields = "__all__" + filter_fields = { + "lang": ["exact", "in"], + "reporter__a_choice": ["exact", "in"], + } + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + all_articles = DjangoFilterConnectionField(ArticleType) + + schema = graphene.Schema(query=Query) + return schema + + +@pytest.fixture +def reporter_article_data(): + john = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + jane = Reporter.objects.create( + first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2 + ) + Article.objects.create( + headline="Article Node 1", reporter=john, editor=john, lang="es", + ) + Article.objects.create( + headline="Article Node 2", reporter=john, editor=john, lang="en", + ) + Article.objects.create( + headline="Article Node 3", reporter=jane, editor=jane, lang="en", + ) + + +def test_filter_enum_on_connection(schema, reporter_article_data): + """ + Check that we can filter with enums on a connection. + """ + query = """ + query { + allArticles(lang: ES) { + edges { + node { + headline + } + } + } + } + """ + + expected = {"allArticles": {"edges": [{"node": {"headline": "Article Node 1"}},]}} + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_filter_on_foreign_key_enum_field(schema, reporter_article_data): + """ + Check that we can filter with enums on a field from a foreign key. + """ + query = """ + query { + allArticles(reporter_AChoice: A_1) { + edges { + node { + headline + } + } + } + } + """ + + expected = { + "allArticles": { + "edges": [ + {"node": {"headline": "Article Node 1"}}, + {"node": {"headline": "Article Node 2"}}, + ] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_filter_enum_field_schema_type(schema): + """ + Check that the type in the filter is an enum like on the object type. + """ + schema_str = str(schema) + + assert ( + '''type ArticleType implements Node { + """The ID of the object""" + id: ID! + headline: String! + pubDate: Date! + pubDateTime: DateTime! + reporter: ReporterType! + editor: ReporterType! + + """Language""" + lang: TestsArticleLangChoices! + importance: TestsArticleImportanceChoices +}''' + in schema_str + ) + + filters = { + "offset": "Int", + "before": "String", + "after": "String", + "first": "Int", + "last": "Int", + "lang": "TestsArticleLangChoices", + "lang_In": "[TestsArticleLangChoices]", + "reporter_AChoice": "TestsReporterAChoiceChoices", + "reporter_AChoice_In": "[TestsReporterAChoiceChoices]", + } + filters_str = ", ".join( + [ + f"{filter_field}: {gql_type} = null" + for filter_field, gql_type in filters.items() + ] + ) + assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 9c94f06e9..274f6ac8c 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -9,7 +9,7 @@ from graphene.relay import Node from graphene_django import DjangoObjectType from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from graphene_django.tests.models import Article, Pet, Reporter +from graphene_django.tests.models import Article, Person, Pet, Reporter from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -90,6 +90,7 @@ def test_filter_explicit_filterset_arguments(): "pub_date__gt", "pub_date__lt", "reporter", + "reporter__in", ) @@ -696,7 +697,7 @@ def resolve_all_reporters(self, info, **args): node { id firstName - articles(lang: "es") { + articles(lang: ES) { edges { node { id @@ -738,6 +739,7 @@ class ReporterType(DjangoObjectType): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(ObjectType): all_reporters = DjangoFilterConnectionField( @@ -1143,7 +1145,7 @@ def get_filters(cls): return filters - def filter_email_in(cls, queryset, name, value): + def filter_email_in(self, queryset, name, value): return queryset.filter(**{name: [value]}) class NewArticleFilter(ArticleFilterMixin, ArticleFilter): @@ -1228,3 +1230,48 @@ class Query(ObjectType): assert not result.errors assert result.data == expected + + +def test_filter_string_contains(): + class PersonType(DjangoObjectType): + class Meta: + model = Person + interfaces = (Node,) + fields = "__all__" + filter_fields = {"name": ["exact", "in", "contains", "icontains"]} + + class Query(ObjectType): + people = DjangoFilterConnectionField(PersonType) + + schema = Schema(query=Query) + + Person.objects.bulk_create( + [ + Person(name="Jack"), + Person(name="Joe"), + Person(name="Jane"), + Person(name="Peter"), + Person(name="Bob"), + ] + ) + query = """query nameContain($filter: String) { + people(name_Contains: $filter) { + edges { + node { + name + } + } + } + }""" + + result = schema.execute(query, variables={"filter": "Ja"}) + assert not result.errors + assert result.data == { + "people": {"edges": [{"node": {"name": "Jack"}}, {"node": {"name": "Jane"}},]} + } + + result = schema.execute(query, variables={"filter": "o"}) + assert not result.errors + assert result.data == { + "people": {"edges": [{"node": {"name": "Joe"}}, {"node": {"name": "Bob"}},]} + } diff --git a/graphene_django/filter/tests/test_in_filter.py b/graphene_django/filter/tests/test_in_filter.py index 9e9c32378..7ad0286ac 100644 --- a/graphene_django/filter/tests/test_in_filter.py +++ b/graphene_django/filter/tests/test_in_filter.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pytest from django_filters import FilterSet @@ -5,7 +7,8 @@ from graphene import ObjectType, Schema from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.tests.models import Pet, Person +from graphene_django.tests.models import Pet, Person, Reporter, Article, Film +from graphene_django.filter.tests.filters import ArticleFilter from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -20,40 +23,77 @@ ) -class PetNode(DjangoObjectType): - class Meta: - model = Pet - interfaces = (Node,) - filter_fields = { - "name": ["exact", "in"], - "age": ["exact", "in", "range"], - } +@pytest.fixture +def query(): + class PetNode(DjangoObjectType): + class Meta: + model = Pet + interfaces = (Node,) + fields = "__all__" + filter_fields = { + "id": ["exact", "in"], + "name": ["exact", "in"], + "age": ["exact", "in", "range"], + } + + class ReporterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + fields = "__all__" + # choice filter using enum + filter_fields = {"reporter_type": ["exact", "in"]} + class ArticleNode(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + fields = "__all__" + filterset_class = ArticleFilter -class PersonFilterSet(FilterSet): - class Meta: - model = Person - fields = {} + class FilmNode(DjangoObjectType): + class Meta: + model = Film + interfaces = (Node,) + fields = "__all__" + # choice filter not using enum + filter_fields = { + "genre": ["exact", "in"], + } + convert_choices_to_enum = False - names = filters.BaseInFilter(method="filter_names") + class PersonFilterSet(FilterSet): + class Meta: + model = Person + fields = {"name": ["in"]} - def filter_names(self, qs, name, value): - return qs.filter(name__in=value) + names = filters.BaseInFilter(method="filter_names") + def filter_names(self, qs, name, value): + """ + This custom filter take a string as input with comma separated values. + Note that the value here is already a list as it has been transformed by the BaseInFilter class. + """ + return qs.filter(name__in=value) -class PersonNode(DjangoObjectType): - class Meta: - model = Person - interfaces = (Node,) - filterset_class = PersonFilterSet + class PersonNode(DjangoObjectType): + class Meta: + model = Person + interfaces = (Node,) + filterset_class = PersonFilterSet + fields = "__all__" + class Query(ObjectType): + pets = DjangoFilterConnectionField(PetNode) + people = DjangoFilterConnectionField(PersonNode) + articles = DjangoFilterConnectionField(ArticleNode) + films = DjangoFilterConnectionField(FilmNode) + reporters = DjangoFilterConnectionField(ReporterNode) -class Query(ObjectType): - pets = DjangoFilterConnectionField(PetNode) - people = DjangoFilterConnectionField(PersonNode) + return Query -def test_string_in_filter(): +def test_string_in_filter(query): """ Test in filter on a string field. """ @@ -61,7 +101,7 @@ def test_string_in_filter(): Pet.objects.create(name="Mimi", age=3) Pet.objects.create(name="Jojo, the rabbit", age=3) - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { @@ -82,17 +122,48 @@ def test_string_in_filter(): ] -def test_string_in_filter_with_filterset_class(): - """Test in filter on a string field with a custom filterset class.""" +def test_string_in_filter_with_otjer_filter(query): + """ + Test in filter on a string field which has also a custom filter doing a similar operation. + """ + Person.objects.create(name="John") + Person.objects.create(name="Michael") + Person.objects.create(name="Angela") + + schema = Schema(query=query) + + query = """ + query { + people (name_In: ["John", "Michael"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["people"]["edges"] == [ + {"node": {"name": "John"}}, + {"node": {"name": "Michael"}}, + ] + + +def test_string_in_filter_with_declared_filter(query): + """ + Test in filter on a string field with a custom filterset class. + """ Person.objects.create(name="John") Person.objects.create(name="Michael") Person.objects.create(name="Angela") - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { - people (names: ["John", "Michael"]) { + people (names: "John,Michael") { edges { node { name @@ -109,7 +180,7 @@ def test_string_in_filter_with_filterset_class(): ] -def test_int_in_filter(): +def test_int_in_filter(query): """ Test in filter on an integer field. """ @@ -117,7 +188,7 @@ def test_int_in_filter(): Pet.objects.create(name="Mimi", age=3) Pet.objects.create(name="Jojo, the rabbit", age=3) - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { @@ -157,7 +228,7 @@ def test_int_in_filter(): ] -def test_in_filter_with_empty_list(): +def test_in_filter_with_empty_list(query): """ Check that using a in filter with an empty list provided as input returns no objects. """ @@ -165,7 +236,7 @@ def test_in_filter_with_empty_list(): Pet.objects.create(name="Mimi", age=8) Pet.objects.create(name="Picotin", age=5) - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { @@ -181,3 +252,197 @@ def test_in_filter_with_empty_list(): result = schema.execute(query) assert not result.errors assert len(result.data["pets"]["edges"]) == 0 + + +def test_choice_in_filter_without_enum(query): + """ + Test in filter o an choice field not using an enum (Film.genre). + """ + + john_doe = Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com" + ) + jean_bon = Reporter.objects.create( + first_name="Jean", last_name="Bon", email="jean@bon.com" + ) + documentary_film = Film.objects.create(genre="do") + documentary_film.reporters.add(john_doe) + action_film = Film.objects.create(genre="ac") + action_film.reporters.add(john_doe) + other_film = Film.objects.create(genre="ot") + other_film.reporters.add(john_doe) + other_film.reporters.add(jean_bon) + + schema = Schema(query=query) + + query = """ + query { + films (genre_In: ["do", "ac"]) { + edges { + node { + genre + reporters { + edges { + node { + lastName + } + } + } + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["films"]["edges"] == [ + { + "node": { + "genre": "do", + "reporters": {"edges": [{"node": {"lastName": "Doe"}}]}, + } + }, + { + "node": { + "genre": "ac", + "reporters": {"edges": [{"node": {"lastName": "Doe"}}]}, + } + }, + ] + + +def test_fk_id_in_filter(query): + """ + Test in filter on an foreign key relationship. + """ + john_doe = Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com" + ) + jean_bon = Reporter.objects.create( + first_name="Jean", last_name="Bon", email="jean@bon.com" + ) + sara_croche = Reporter.objects.create( + first_name="Sara", last_name="Croche", email="sara@croche.com" + ) + Article.objects.create( + headline="A", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=john_doe, + editor=john_doe, + ) + Article.objects.create( + headline="B", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=jean_bon, + editor=jean_bon, + ) + Article.objects.create( + headline="C", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=sara_croche, + editor=sara_croche, + ) + + schema = Schema(query=query) + + query = """ + query { + articles (reporter_In: [%s, %s]) { + edges { + node { + headline + reporter { + lastName + } + } + } + } + } + """ % ( + john_doe.id, + jean_bon.id, + ) + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [ + {"node": {"headline": "A", "reporter": {"lastName": "Doe"}}}, + {"node": {"headline": "B", "reporter": {"lastName": "Bon"}}}, + ] + + +def test_enum_in_filter(query): + """ + Test in filter on a choice field using an enum (Reporter.reporter_type). + """ + + Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1 + ) + Reporter.objects.create( + first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2 + ) + Reporter.objects.create( + first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2 + ) + Reporter.objects.create( + first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None + ) + + schema = Schema(query=query) + + query = """ + query { + reporters (reporterType_In: [A_1]) { + edges { + node { + email + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["reporters"]["edges"] == [ + {"node": {"email": "john@doe.com"}}, + ] + + query = """ + query { + reporters (reporterType_In: [A_2]) { + edges { + node { + email + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["reporters"]["edges"] == [ + {"node": {"email": "jean@bon.com"}}, + {"node": {"email": "jane@doe.com"}}, + ] + + query = """ + query { + reporters (reporterType_In: [A_2, A_1]) { + edges { + node { + email + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["reporters"]["edges"] == [ + {"node": {"email": "john@doe.com"}}, + {"node": {"email": "jean@bon.com"}}, + {"node": {"email": "jane@doe.com"}}, + ] diff --git a/graphene_django/filter/tests/test_range_filter.py b/graphene_django/filter/tests/test_range_filter.py index 644ec5df4..6227a7071 100644 --- a/graphene_django/filter/tests/test_range_filter.py +++ b/graphene_django/filter/tests/test_range_filter.py @@ -25,6 +25,7 @@ class PetNode(DjangoObjectType): class Meta: model = Pet interfaces = (Node,) + fields = "__all__" filter_fields = { "name": ["exact", "in"], "age": ["exact", "in", "range"], diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 4530599e5..d4fc1bf84 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -1,53 +1,101 @@ import graphene -from django_filters.utils import get_model_field +from django import forms + +from django_filters.utils import get_model_field, get_field_parts from django_filters.filters import Filter, BaseCSVFilter from .filterset import custom_filterset_factory, setup_filterset -from .filters import InFilter, RangeFilter +from .filters import ArrayFilter, ListFilter, RangeFilter +from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField + + +def get_field_type(registry, model, field_name): + """ + Try to get a model field corresponding Graphql type from the DjangoObjectType. + """ + object_type = registry.get_type_for_model(model) + if object_type: + object_type_field = object_type._meta.fields.get(field_name) + if object_type_field: + field_type = object_type_field.type + if isinstance(field_type, graphene.NonNull): + field_type = field_type.of_type + return field_type + return None def get_filtering_args_from_filterset(filterset_class, type): - """ Inspect a FilterSet and produce the arguments to pass to - a Graphene Field. These arguments will be available to - filter against in the GraphQL + """ + Inspect a FilterSet and produce the arguments to pass to a Graphene Field. + These arguments will be available to filter against in the GraphQL API. """ from ..forms.converter import convert_form_field args = {} model = filterset_class._meta.model + registry = type._meta.registry for name, filter_field in filterset_class.base_filters.items(): - form_field = None filter_type = filter_field.lookup_expr + field_type = None + form_field = None - if name in filterset_class.declared_filters: - # Get the filter field from the explicitly declared filter - form_field = filter_field.field - field = convert_form_field(form_field) - else: - # Get the filter field with no explicit type declaration - model_field = get_model_field(model, filter_field.field_name) - if filter_type != "isnull" and hasattr(model_field, "formfield"): - form_field = model_field.formfield( - required=filter_field.extra.get("required", False) - ) - - # Fallback to field defined on filter if we can't get it from the - # model field - if not form_field: - form_field = filter_field.field - - field = convert_form_field(form_field) - - if filter_type in {"in", "range", "contains", "overlap"}: - # Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of - # the same type as the field. See comments in - # `replace_csv_filters` method for more details. - field = graphene.List(field.get_type()) - - field_type = field.Argument() - field_type.description = str(filter_field.label) if filter_field.label else None - args[name] = field_type + if ( + name not in filterset_class.declared_filters + or isinstance(filter_field, ListFilter) + or isinstance(filter_field, RangeFilter) + or isinstance(filter_field, ArrayFilter) + ): + # Get the filter field for filters that are no explicitly declared. + + required = filter_field.extra.get("required", False) + if filter_type == "isnull": + field = graphene.Boolean(required=required) + else: + model_field = get_model_field(model, filter_field.field_name) + + # Get the form field either from: + # 1. the formfield corresponding to the model field + # 2. the field defined on filter + if hasattr(model_field, "formfield"): + form_field = model_field.formfield(required=required) + if not form_field: + form_field = filter_field.field + + # First try to get the matching field type from the GraphQL DjangoObjectType + if model_field: + if ( + isinstance(form_field, forms.ModelChoiceField) + or isinstance(form_field, forms.ModelMultipleChoiceField) + or isinstance(form_field, GlobalIDMultipleChoiceField) + or isinstance(form_field, GlobalIDFormField) + ): + # Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID. + field_type = get_field_type( + registry, model_field.related_model, "id" + ) + else: + field_type = get_field_type( + registry, model_field.model, model_field.name + ) + + if not field_type: + # Fallback on converting the form field either because: + # - it's an explicitly declared filters + # - we did not manage to get the type from the model type + form_field = form_field or filter_field.field + field_type = convert_form_field(form_field) + + if isinstance(filter_field, ListFilter) or isinstance( + filter_field, RangeFilter + ): + # Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of + # the same type as the field. See comments in `replace_csv_filters` method for more details. + field_type = graphene.List(field_type.get_type()) + + args[name] = graphene.Argument( + field_type.get_type(), description=filter_field.label, required=required, + ) return args @@ -69,18 +117,26 @@ def get_filterset_class(filterset_class, **meta): def replace_csv_filters(filterset_class): """ - Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore - but regular Filter objects that simply use the input value as filter argument on the queryset. + Replace the "in" and "range" filters (that are not explicitly declared) + to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore + but our custom InFilter/RangeFilter filter class that use the input + value as filter argument on the queryset. - This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we - can actually have a list as input and have a proper type verification of each value in the list. + This is because those BaseCSVFilter are expecting a string as input with + comma separated values. + But with GraphQl we can actually have a list as input and have a proper + type verification of each value in the list. See issue https://github.com/graphql-python/graphene-django/issues/1068. """ for name, filter_field in list(filterset_class.base_filters.items()): + # Do not touch any declared filters + if name in filterset_class.declared_filters: + continue + filter_type = filter_field.lookup_expr - if filter_type in {"in", "contains", "overlap"}: - filterset_class.base_filters[name] = InFilter( + if filter_type == "in": + filterset_class.base_filters[name] = ListFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, label=filter_field.label, @@ -88,7 +144,6 @@ def replace_csv_filters(filterset_class): exclude=filter_field.exclude, **filter_field.extra ) - elif filter_type == "range": filterset_class.base_filters[name] = RangeFilter( field_name=filter_field.field_name, diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 180acc527..7b76cd378 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -26,7 +26,7 @@ class Film(models.Model): genre = models.CharField( max_length=2, help_text="Genre", - choices=[("do", "Documentary"), ("ot", "Other")], + choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")], default="ot", ) reporters = models.ManyToManyField("Reporter", related_name="films") @@ -91,8 +91,8 @@ class Meta: class Article(models.Model): headline = models.CharField(max_length=100) - pub_date = models.DateField() - pub_date_time = models.DateTimeField() + pub_date = models.DateField(auto_now_add=True) + pub_date_time = models.DateTimeField(auto_now_add=True) reporter = models.ForeignKey( Reporter, on_delete=models.CASCADE, related_name="articles" ) diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 699814d2c..aabe19ceb 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -421,6 +421,7 @@ class Meta: interfaces = (Node,) fields = "__all__" filter_fields = ("lang",) + convert_choices_to_enum = False class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -546,6 +547,7 @@ class Meta: interfaces = (Node,) fields = "__all__" filter_fields = ("lang", "headline") + convert_choices_to_enum = False class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1251,6 +1253,7 @@ class ReporterType(DjangoObjectType): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1455,6 +1458,7 @@ class ReporterType(DjangoObjectType): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1494,6 +1498,7 @@ class ReporterType(DjangoObjectType): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1527,6 +1532,7 @@ class ReporterType(DjangoObjectType): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1561,6 +1567,7 @@ class ReporterType(DjangoObjectType): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index cb653e1cf..bde72c7a6 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -671,6 +671,7 @@ class Reporter(DjangoObjectType): class Meta: model = ReporterModel name = "CustomReporterName" + fields = "__all__" filter_fields = ["email"] interfaces = (Node,)