Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions graphene_django/filter/tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from datetime import datetime
from textwrap import dedent

import pytest
from django.db.models import TextField, Value
from django.db.models.functions import Concat

from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene import Argument, Boolean, Field, Float, ObjectType, Schema, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED

# for annotation test
from django.db.models import TextField, Value
from django.db.models.functions import Concat

pytestmark = []

if DJANGO_FILTER_INSTALLED:
Expand Down Expand Up @@ -183,7 +182,7 @@ class context(object):
}
"""
schema = Schema(query=Query)
result = schema.execute(query, context_value=context())
result = schema.execute(query, context=context())
assert not result.errors

assert len(result.data["contextArticles"]["edges"]) == 1
Expand Down Expand Up @@ -462,15 +461,15 @@ class Meta:
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)

r1 = Reporter.objects.create(
Reporter.objects.create(
first_name="A test user", last_name="Last Name", email="test1@test.com"
)
r2 = Reporter.objects.create(
Reporter.objects.create(
first_name="Other test user",
last_name="Other Last Name",
email="test2@test.com",
)
r3 = Reporter.objects.create(
Reporter.objects.create(
first_name="Random", last_name="RandomLast", email="random@test.com"
)

Expand Down Expand Up @@ -638,7 +637,7 @@ def resolve_all_reporters(self, info, **args):
Reporter.objects.create(
first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)

Expand Down Expand Up @@ -684,7 +683,7 @@ def resolve_all_reporters(self, info, reverse_order=False, **args):
return reporters

Reporter.objects.create(first_name="b")
r = Reporter.objects.create(first_name="a")
Reporter.objects.create(first_name="a")

schema = Schema(query=Query)
query = """
Expand Down Expand Up @@ -767,3 +766,55 @@ def resolve_all_reporters(self, info, **args):

assert not result.errors
assert result.data == expected


def test_integer_field_filter_type():
class PetType(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {"age": ["exact"]}
only_fields = ["age"]

class Query(ObjectType):
pets = DjangoFilterConnectionField(PetType)

schema = Schema(query=Query)

assert str(schema) == dedent(
"""\
schema {
query: Query
}

interface Node {
id: ID!
}

type PageInfo {
hasNextPage: Boolean!
hasPreviousPage: Boolean!
startCursor: String
endCursor: String
}

type PetType implements Node {
age: Int!
id: ID!
}

type PetTypeConnection {
pageInfo: PageInfo!
edges: [PetTypeEdge]!
}

type PetTypeEdge {
node: PetType
cursor: String!
}

type Query {
pets(before: String, after: String, first: Int, last: Int, age: Int): PetTypeConnection
}
"""
)
19 changes: 18 additions & 1 deletion graphene_django/filter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,25 @@ def get_filtering_args_from_filterset(filterset_class, type):
from ..forms.converter import convert_form_field

args = {}
model = filterset_class._meta.model
for name, filter_field in six.iteritems(filterset_class.base_filters):
field_type = convert_form_field(filter_field.field).Argument()
if name in filterset_class.declared_filters:
form_field = filter_field.field
else:
field_name = name.split("__", 1)[0]
model_field = model._meta.get_field(field_name)

if hasattr(model_field, "formfield"):
form_field = model_field.formfield(
required=filter_field.extra.get("required", False)
)

# Fallback to field defined on filter if we can't get it from the
# model field
if not form_field:
form_field = filter_field.field

field_type = convert_form_field(form_field).Argument()
field_type.description = filter_field.label
args[name] = field_type

Expand Down
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@ omit = */tests/*

[isort]
known_first_party=graphene,graphene_django
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88