diff --git a/graphene_django/compat.py b/graphene_django/compat.py index fde632aa..b3d160a1 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -1,3 +1,6 @@ +import sys +from pathlib import PurePath + # For backwards compatibility, we import JSONField to have it available for import via # this compat module (https://github.com/graphql-python/graphene-django/issues/1428). # Django's JSONField is available in Django 3.2+ (the minimum version we support) @@ -19,4 +22,23 @@ def __init__(self, *args, **kwargs): RangeField, ) except ImportError: - IntegerRangeField, ArrayField, HStoreField, RangeField = (MissingType,) * 4 + IntegerRangeField, HStoreField, RangeField = (MissingType,) * 3 + + # For unit tests we fake ArrayField using JSONFields + if any( + PurePath(sys.argv[0]).match(p) + for p in [ + "**/pytest", + "**/py.test", + "**/pytest/__main__.py", + ] + ): + + class ArrayField(JSONField): + def __init__(self, *args, **kwargs): + if len(args) > 0: + self.base_field = args[0] + super().__init__(**kwargs) + + else: + ArrayField = MissingType diff --git a/graphene_django/filter/filters/array_filter.py b/graphene_django/filter/filters/array_filter.py index b6f4808e..a2fccda3 100644 --- a/graphene_django/filter/filters/array_filter.py +++ b/graphene_django/filter/filters/array_filter.py @@ -1,13 +1,36 @@ from django_filters.constants import EMPTY_VALUES +from django_filters.filters import FilterMethod from .typed_filter import TypedFilter +class ArrayFilterMethod(FilterMethod): + def __call__(self, qs, value): + if value is None: + return qs + return self.method(qs, self.f.field_name, value) + + class ArrayFilter(TypedFilter): """ Filter made for PostgreSQL ArrayField. """ + @TypedFilter.method.setter + def method(self, value): + """ + Override method setter so that in case a custom `method` is provided + (see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method), + it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method + of the `FilterMethod` class) and instead use our ArrayFilterMethod that consider empty lists as values. + + Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)` + which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead. + """ + TypedFilter.method.fset(self, value) + if value is not None: + self.filter = ArrayFilterMethod(self) + def filter(self, qs, value): """ Override the default filter class to check first whether the list is diff --git a/graphene_django/filter/filters/list_filter.py b/graphene_django/filter/filters/list_filter.py index 6689877c..db91409c 100644 --- a/graphene_django/filter/filters/list_filter.py +++ b/graphene_django/filter/filters/list_filter.py @@ -1,12 +1,36 @@ +from django_filters.filters import FilterMethod + from .typed_filter import TypedFilter +class ListFilterMethod(FilterMethod): + def __call__(self, qs, value): + if value is None: + return qs + return self.method(qs, self.f.field_name, value) + + class ListFilter(TypedFilter): """ Filter that takes a list of value as input. It is for example used for `__in` filters. """ + @TypedFilter.method.setter + def method(self, value): + """ + Override method setter so that in case a custom `method` is provided + (see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method), + it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method + of the `FilterMethod` class) and instead use our ListFilterMethod that consider empty lists as values. + + Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)` + which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead. + """ + TypedFilter.method.fset(self, value) + if value is not None: + self.filter = ListFilterMethod(self) + def filter(self, qs, value): """ Override the default filter class to check first whether the list is diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py index 1556f545..9f5d3666 100644 --- a/graphene_django/filter/tests/conftest.py +++ b/graphene_django/filter/tests/conftest.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from functools import reduce import pytest from django.db import models @@ -25,15 +25,15 @@ ) -STORE = {"events": []} - - class Event(models.Model): name = models.CharField(max_length=50) tags = ArrayField(models.CharField(max_length=50)) tag_ids = ArrayField(models.IntegerField()) random_field = ArrayField(models.BooleanField()) + def __repr__(self): + return f"Event [{self.name}]" + @pytest.fixture def EventFilterSet(): @@ -48,6 +48,14 @@ class Meta: tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") tags = ArrayFilter(field_name="tags", lookup_expr="exact") + tags__len = ArrayFilter( + field_name="tags", lookup_expr="len", input_type=graphene.Int + ) + tags__len__in = ArrayFilter( + field_name="tags", + method="tags__len__in_filter", + input_type=graphene.List(graphene.Int), + ) # Those are actually not usable and only to check type declarations tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains") @@ -61,6 +69,14 @@ class Meta: ) random_field = ArrayFilter(field_name="random_field", lookup_expr="exact") + def tags__len__in_filter(self, queryset, _name, value): + if not value: + return queryset.none() + return reduce( + lambda q1, q2: q1.union(q2), + [queryset.filter(tags__len=v) for v in value], + ).distinct() + return EventFilterSet @@ -83,68 +99,94 @@ def Query(EventType): we are running unit tests in sqlite which does not have ArrayFields. """ + events = [ + Event(name="Live Show", tags=["concert", "music", "rock"]), + Event(name="Musical", tags=["movie", "music"]), + Event(name="Ballet", tags=["concert", "dance"]), + Event(name="Speech", tags=[]), + ] + class Query(graphene.ObjectType): events = DjangoFilterConnectionField(EventType) def resolve_events(self, info, **kwargs): - events = [ - Event(name="Live Show", tags=["concert", "music", "rock"]), - Event(name="Musical", tags=["movie", "music"]), - Event(name="Ballet", tags=["concert", "dance"]), - Event(name="Speech", tags=[]), - ] - - STORE["events"] = events - - m_queryset = MagicMock(spec=QuerySet) - m_queryset.model = Event - - def filter_events(**kwargs): - if "tags__contains" in kwargs: - STORE["events"] = list( - filter( - lambda e: set(kwargs["tags__contains"]).issubset( - set(e.tags) - ), - STORE["events"], + class FakeQuerySet(QuerySet): + def __init__(self, model=None): + self.model = Event + self.__store = list(events) + + def all(self): + return self + + def filter(self, **kwargs): + queryset = FakeQuerySet() + queryset.__store = list(self.__store) + if "tags__contains" in kwargs: + queryset.__store = list( + filter( + lambda e: set(kwargs["tags__contains"]).issubset( + set(e.tags) + ), + queryset.__store, + ) + ) + if "tags__overlap" in kwargs: + queryset.__store = list( + filter( + lambda e: not set(kwargs["tags__overlap"]).isdisjoint( + set(e.tags) + ), + queryset.__store, + ) ) - ) - if "tags__overlap" in kwargs: - STORE["events"] = list( - filter( - lambda e: not set(kwargs["tags__overlap"]).isdisjoint( - set(e.tags) - ), - STORE["events"], + if "tags__exact" in kwargs: + queryset.__store = list( + filter( + lambda e: set(kwargs["tags__exact"]) == set(e.tags), + queryset.__store, + ) ) - ) - if "tags__exact" in kwargs: - STORE["events"] = list( - filter( - lambda e: set(kwargs["tags__exact"]) == set(e.tags), - STORE["events"], + if "tags__len" in kwargs: + queryset.__store = list( + filter( + lambda e: len(e.tags) == kwargs["tags__len"], + queryset.__store, + ) ) - ) + return queryset + + def union(self, *args): + queryset = FakeQuerySet() + queryset.__store = self.__store + for arg in args: + queryset.__store += arg.__store + return queryset - def mock_queryset_filter(*args, **kwargs): - filter_events(**kwargs) - return m_queryset + def none(self): + queryset = FakeQuerySet() + queryset.__store = [] + return queryset - def mock_queryset_none(*args, **kwargs): - STORE["events"] = [] - return m_queryset + def count(self): + return len(self.__store) - def mock_queryset_count(*args, **kwargs): - return len(STORE["events"]) + def distinct(self): + queryset = FakeQuerySet() + queryset.__store = [] + for event in self.__store: + if event not in queryset.__store: + queryset.__store.append(event) + queryset.__store = sorted(queryset.__store, key=lambda e: e.name) + return queryset - m_queryset.all.return_value = m_queryset - m_queryset.filter.side_effect = mock_queryset_filter - m_queryset.none.side_effect = mock_queryset_none - m_queryset.count.side_effect = mock_queryset_count - m_queryset.__getitem__.side_effect = lambda index: STORE[ - "events" - ].__getitem__(index) + def __getitem__(self, index): + return self.__store[index] - return m_queryset + return FakeQuerySet() return Query + + +@pytest.fixture +def schema(Query): + return graphene.Schema(query=Query) diff --git a/graphene_django/filter/tests/test_array_field_contains_filter.py b/graphene_django/filter/tests/test_array_field_contains_filter.py index 4144614c..52a9f241 100644 --- a/graphene_django/filter/tests/test_array_field_contains_filter.py +++ b/graphene_django/filter/tests/test_array_field_contains_filter.py @@ -1,18 +1,14 @@ import pytest -from graphene import Schema - from ...compat import ArrayField, MissingType @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_contains_multiple(Query): +def test_array_field_contains_multiple(schema): """ Test contains filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags_Contains: ["concert", "music"]) { @@ -32,13 +28,11 @@ def test_array_field_contains_multiple(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_contains_one(Query): +def test_array_field_contains_one(schema): """ Test contains filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags_Contains: ["music"]) { @@ -59,13 +53,11 @@ def test_array_field_contains_one(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_contains_empty_list(Query): +def test_array_field_contains_empty_list(schema): """ Test contains filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags_Contains: []) { diff --git a/graphene_django/filter/tests/test_array_field_custom_filter.py b/graphene_django/filter/tests/test_array_field_custom_filter.py new file mode 100644 index 00000000..3fdb992a --- /dev/null +++ b/graphene_django/filter/tests/test_array_field_custom_filter.py @@ -0,0 +1,186 @@ +import pytest + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_len_filter(schema): + query = """ + query { + events (tags_Len: 2) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Musical"}}, + {"node": {"name": "Ballet"}}, + ] + + query = """ + query { + events (tags_Len: 0) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Speech"}}, + ] + + query = """ + query { + events (tags_Len: 10) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] + + query = """ + query { + events (tags_Len: "2") { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert len(result.errors) == 1 + assert result.errors[0].message == 'Int cannot represent non-integer value: "2"' + + query = """ + query { + events (tags_Len: True) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert len(result.errors) == 1 + assert result.errors[0].message == "Int cannot represent non-integer value: True" + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_custom_filter(schema): + query = """ + query { + events (tags_Len_In: 2) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Ballet"}}, + {"node": {"name": "Musical"}}, + ] + + query = """ + query { + events (tags_Len_In: [0, 2]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Ballet"}}, + {"node": {"name": "Musical"}}, + {"node": {"name": "Speech"}}, + ] + + query = """ + query { + events (tags_Len_In: [10]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] + + query = """ + query { + events (tags_Len_In: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] + + query = """ + query { + events (tags_Len_In: "12") { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert len(result.errors) == 1 + assert result.errors[0].message == 'Int cannot represent non-integer value: "12"' + + query = """ + query { + events (tags_Len_In: True) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert len(result.errors) == 1 + assert result.errors[0].message == "Int cannot represent non-integer value: True" diff --git a/graphene_django/filter/tests/test_array_field_exact_filter.py b/graphene_django/filter/tests/test_array_field_exact_filter.py index 10e32ef7..5cba1937 100644 --- a/graphene_django/filter/tests/test_array_field_exact_filter.py +++ b/graphene_django/filter/tests/test_array_field_exact_filter.py @@ -1,18 +1,14 @@ import pytest -from graphene import Schema - from ...compat import ArrayField, MissingType @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_exact_no_match(Query): +def test_array_field_exact_no_match(schema): """ Test exact filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags: ["concert", "music"]) { @@ -30,13 +26,11 @@ def test_array_field_exact_no_match(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_exact_match(Query): +def test_array_field_exact_match(schema): """ Test exact filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags: ["movie", "music"]) { @@ -56,13 +50,11 @@ def test_array_field_exact_match(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_exact_empty_list(Query): +def test_array_field_exact_empty_list(schema): """ Test exact filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags: []) { @@ -82,11 +74,10 @@ def test_array_field_exact_empty_list(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_filter_schema_type(Query): +def test_array_field_filter_schema_type(schema): """ Check that the type in the filter is an array field like on the object type. """ - schema = Schema(query=Query) schema_str = str(schema) assert ( @@ -112,6 +103,8 @@ def test_array_field_filter_schema_type(Query): "tags_Contains": "[String!]", "tags_Overlap": "[String!]", "tags": "[String!]", + "tags_Len": "Int", + "tags_Len_In": "[Int]", "tagsIds_Contains": "[Int!]", "tagsIds_Overlap": "[Int!]", "tagsIds": "[Int!]", diff --git a/graphene_django/filter/tests/test_array_field_overlap_filter.py b/graphene_django/filter/tests/test_array_field_overlap_filter.py index 5ce1576b..95d339c3 100644 --- a/graphene_django/filter/tests/test_array_field_overlap_filter.py +++ b/graphene_django/filter/tests/test_array_field_overlap_filter.py @@ -1,18 +1,14 @@ import pytest -from graphene import Schema - from ...compat import ArrayField, MissingType @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_overlap_multiple(Query): +def test_array_field_overlap_multiple(schema): """ Test overlap filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags_Overlap: ["concert", "music"]) { @@ -34,13 +30,11 @@ def test_array_field_overlap_multiple(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_overlap_one(Query): +def test_array_field_overlap_one(schema): """ Test overlap filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags_Overlap: ["music"]) { @@ -61,13 +55,11 @@ def test_array_field_overlap_one(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_array_field_overlap_empty_list(Query): +def test_array_field_overlap_empty_list(schema): """ Test overlap filter on a array field of string. """ - schema = Schema(query=Query) - query = """ query { events (tags_Overlap: []) { diff --git a/graphene_django/filter/tests/test_typed_filter.py b/graphene_django/filter/tests/test_typed_filter.py index 084affad..b155385d 100644 --- a/graphene_django/filter/tests/test_typed_filter.py +++ b/graphene_django/filter/tests/test_typed_filter.py @@ -1,4 +1,8 @@ +import operator +from functools import reduce + import pytest +from django.db.models import Q from django_filters import FilterSet import graphene @@ -44,6 +48,10 @@ class Meta: only_first = TypedFilter( input_type=graphene.Boolean, method="only_first_filter" ) + headline_search = ListFilter( + method="headline_search_filter", + input_type=graphene.List(graphene.String), + ) def first_n_filter(self, queryset, _name, value): return queryset[:value] @@ -54,6 +62,13 @@ def only_first_filter(self, queryset, _name, value): else: return queryset + def headline_search_filter(self, queryset, _name, value): + if not value: + return queryset.none() + return queryset.filter( + reduce(operator.or_, [Q(headline__icontains=v) for v in value]) + ) + class ArticleType(DjangoObjectType): class Meta: model = Article @@ -87,6 +102,7 @@ def test_typed_filter_schema(schema): "lang_InStr": "[String]", "firstN": "Int", "onlyFirst": "Boolean", + "headlineSearch": "[String]", } all_articles_filters = ( @@ -104,6 +120,40 @@ def test_typed_filters_work(schema): Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es") Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es") Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en") + Article.objects.create(headline="AB", reporter=reporter, editor=reporter, lang="es") + + query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }' + + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [ + {"node": {"headline": "C"}}, + ] + + query = "query { articles (firstN: 2) { edges { node { headline } } } }" + + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [ + {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, + ] + + query = "query { articles (onlyFirst: true) { edges { node { headline } } } }" + + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [ + {"node": {"headline": "A"}}, + ] + + +def test_list_filters_work(schema): + reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="") + Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es") + Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es") + Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en") + Article.objects.create(headline="AB", reporter=reporter, editor=reporter, lang="es") query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }" @@ -111,6 +161,7 @@ def test_typed_filters_work(schema): assert not result.errors assert result.data["articles"]["edges"] == [ {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, {"node": {"headline": "B"}}, ] @@ -120,30 +171,61 @@ def test_typed_filters_work(schema): assert not result.errors assert result.data["articles"]["edges"] == [ {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, {"node": {"headline": "B"}}, ] - query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }' + query = "query { articles (lang_InStr: []) { edges { node { headline } } } }" + + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [] + + query = "query { articles (lang_InStr: null) { edges { node { headline } } } }" result = schema.execute(query) assert not result.errors assert result.data["articles"]["edges"] == [ + {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, + {"node": {"headline": "B"}}, {"node": {"headline": "C"}}, ] - query = "query { articles (firstN: 2) { edges { node { headline } } } }" + query = 'query { articles (headlineSearch: ["a", "B"]) { edges { node { headline } } } }' result = schema.execute(query) assert not result.errors assert result.data["articles"]["edges"] == [ {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, {"node": {"headline": "B"}}, ] - query = "query { articles (onlyFirst: true) { edges { node { headline } } } }" + query = "query { articles (headlineSearch: []) { edges { node { headline } } } }" + + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [] + + query = "query { articles (headlineSearch: null) { edges { node { headline } } } }" + + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [ + {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, + {"node": {"headline": "B"}}, + {"node": {"headline": "C"}}, + ] + + query = 'query { articles (headlineSearch: [""]) { edges { node { headline } } } }' result = schema.execute(query) assert not result.errors assert result.data["articles"]["edges"] == [ {"node": {"headline": "A"}}, + {"node": {"headline": "AB"}}, + {"node": {"headline": "B"}}, + {"node": {"headline": "C"}}, ]