diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 1ecce454c..34d55b9e7 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,10 +1,12 @@ from functools import partial +from collections import OrderedDict from django.db.models.query import QuerySet from promise import Promise from graphene.types import Field, List +from graphene.types.argument import to_arguments from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -29,7 +31,7 @@ def get_resolver(self, parent_resolver): class DjangoConnectionField(ConnectionField): - def __init__(self, *args, **kwargs): + def __init__(self, type, *args, **kwargs): self.on = kwargs.pop("on", False) self.max_limit = kwargs.pop( "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT @@ -38,7 +40,17 @@ def __init__(self, *args, **kwargs): "enforce_first_or_last", graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST, ) - super(DjangoConnectionField, self).__init__(*args, **kwargs) + super(DjangoConnectionField, self).__init__(type, *args, **kwargs) + + @property + def args(self): + args = OrderedDict(self._base_args) + args.update(self.node_type.get_connection_parameters()) + return to_arguments(args) + + @args.setter + def args(self, args): + self._base_args = args @property def type(self): @@ -76,7 +88,7 @@ def merge_querysets(cls, default_queryset, queryset): return queryset & default_queryset @classmethod - def resolve_connection(cls, connection, default_manager, args, iterable): + def resolve_connection(cls, connection, node, default_manager, args, info, iterable): if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) @@ -84,6 +96,9 @@ def resolve_connection(cls, connection, default_manager, args, iterable): if iterable is not default_manager: default_queryset = maybe_queryset(default_manager) iterable = cls.merge_querysets(default_queryset, iterable) + from .types import DjangoObjectType + if issubclass(node, DjangoObjectType): + iterable = node.refine_queryset(iterable, info, **args) _len = iterable.count() else: _len = len(iterable) @@ -103,15 +118,16 @@ def resolve_connection(cls, connection, default_manager, args, iterable): @classmethod def connection_resolver( - cls, - resolver, - connection, - default_manager, - max_limit, - enforce_first_or_last, - root, - info, - **args + cls, + resolver, + connection, + node, + default_manager, + max_limit, + enforce_first_or_last, + root, + info, + **args ): first = args.get("first") last = args.get("last") @@ -135,7 +151,7 @@ def connection_resolver( args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + on_resolve = partial(cls.resolve_connection, connection, node, default_manager, args, info) if Promise.is_thenable(iterable): return Promise.resolve(iterable).then(on_resolve) @@ -147,6 +163,7 @@ def get_resolver(self, parent_resolver): self.connection_resolver, parent_resolver, self.type, + self.node_type, self.get_manager(), self.max_limit, self.enforce_first_or_last, diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index cb4254370..340eb37a7 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -76,17 +76,18 @@ def merge_querysets(cls, default_queryset, queryset): @classmethod def connection_resolver( - cls, - resolver, - connection, - default_manager, - max_limit, - enforce_first_or_last, - filterset_class, - filtering_args, - root, - info, - **args + cls, + resolver, + connection, + node, + default_manager, + max_limit, + enforce_first_or_last, + filterset_class, + filtering_args, + root, + info, + **args ): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( @@ -98,6 +99,7 @@ def connection_resolver( return super(DjangoFilterConnectionField, cls).connection_resolver( resolver, connection, + node, qs, max_limit, enforce_first_or_last, @@ -111,6 +113,7 @@ def get_resolver(self, parent_resolver): self.connection_resolver, parent_resolver, self.type, + self.node_type, self.get_manager(), self.max_limit, self.enforce_first_or_last, diff --git a/graphene_django/types.py b/graphene_django/types.py index aa8b5a30c..1afd0450f 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -111,6 +111,14 @@ def __init_subclass_with_meta__( if not skip_registry: registry.register(cls) + @classmethod + def refine_queryset(cls, qs, info, **kwargs): + return qs + + @classmethod + def get_connection_parameters(cls): + return {} + def resolve_id(self, info): return self.pk @@ -130,6 +138,6 @@ def is_type_of(cls, root, info): @classmethod def get_node(cls, info, id): try: - return cls._meta.model.objects.get(pk=id) + return cls.refine_queryset(cls._meta.model.objects, info).get(pk=id) except cls._meta.model.DoesNotExist: return None