diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index a46a4b738..3a98e8d18 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,6 +1,7 @@ from collections import OrderedDict from functools import partial +from django.core.exceptions import ValidationError from graphene.types.argument import to_arguments from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -59,7 +60,12 @@ def resolve_queryset( connection, iterable, info, args ) filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} - return filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs + filterset = filterset_class( + data=filter_kwargs, queryset=qs, request=info.context + ) + if filterset.form.is_valid(): + return filterset.qs + raise ValidationError(filterset.form.errors.as_json()) def get_queryset_resolver(self): return partial( diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 166d806fd..b8ae6fe24 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -400,6 +400,114 @@ def test_global_id_field_relation(): assert id_filter.field_class == GlobalIDFormField +def test_global_id_field_relation_with_filter(): + class ReporterFilterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = ["first_name", "articles"] + + class ArticleFilterNode(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + filter_fields = ["headline", "reporter"] + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + all_articles = DjangoFilterConnectionField(ArticleFilterNode) + reporter = Field(ReporterFilterNode) + article = Field(ArticleFilterNode) + + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) + + # Query articles created by the reporter `r1` + query = """ + query { + allArticles (reporter: "UmVwb3J0ZXJGaWx0ZXJOb2RlOjE=") { + edges { + node { + id + } + } + } + } + """ + schema = Schema(query=Query) + result = schema.execute(query) + assert not result.errors + # We should only get back a single article + assert len(result.data["allArticles"]["edges"]) == 1 + + +def test_global_id_field_relation_with_filter_not_valid_id(): + class ReporterFilterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = ["first_name", "articles"] + + class ArticleFilterNode(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + filter_fields = ["headline", "reporter"] + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + all_articles = DjangoFilterConnectionField(ArticleFilterNode) + reporter = Field(ReporterFilterNode) + article = Field(ArticleFilterNode) + + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) + + # Filter by the global ID that does not exist + query = """ + query { + allArticles (reporter: "fake_global_id") { + edges { + node { + id + } + } + } + } + """ + schema = Schema(query=Query) + result = schema.execute(query) + assert "Invalid ID specified." in result.errors[0].message + + def test_global_id_multiple_field_implicit(): field = DjangoFilterConnectionField(ReporterNode, fields=["pets"]) filterset_class = field.filterset_class