Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from functools import partial
from collections import OrderedDict

from django.db.models.query import QuerySet

from promise import Promise

from graphene.types import Field, List
from graphene.types.argument import to_arguments
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

Expand All @@ -29,7 +31,7 @@ def get_resolver(self, parent_resolver):


class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
def __init__(self, type, *args, **kwargs):
self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
Expand All @@ -38,7 +40,17 @@ def __init__(self, *args, **kwargs):
"enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
super(DjangoConnectionField, self).__init__(type, *args, **kwargs)

@property
def args(self):
args = OrderedDict(self._base_args)
args.update(self.node_type.get_connection_parameters())
return to_arguments(args)

@args.setter
def args(self, args):
self._base_args = args

@property
def type(self):
Expand Down Expand Up @@ -76,14 +88,17 @@ def merge_querysets(cls, default_queryset, queryset):
return queryset & default_queryset

@classmethod
def resolve_connection(cls, connection, default_manager, args, iterable):
def resolve_connection(cls, connection, node, default_manager, args, info, iterable):
if iterable is None:
iterable = default_manager
iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet):
if iterable is not default_manager:
default_queryset = maybe_queryset(default_manager)
iterable = cls.merge_querysets(default_queryset, iterable)
from .types import DjangoObjectType
if issubclass(node, DjangoObjectType):
iterable = node.refine_queryset(iterable, info, **args)
_len = iterable.count()
else:
_len = len(iterable)
Expand All @@ -103,15 +118,16 @@ def resolve_connection(cls, connection, default_manager, args, iterable):

@classmethod
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
cls,
resolver,
connection,
node,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
Expand All @@ -135,7 +151,7 @@ def connection_resolver(
args["last"] = min(last, max_limit)

iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
on_resolve = partial(cls.resolve_connection, connection, node, default_manager, args, info)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
Expand All @@ -147,6 +163,7 @@ def get_resolver(self, parent_resolver):
self.connection_resolver,
parent_resolver,
self.type,
self.node_type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last,
Expand Down
25 changes: 14 additions & 11 deletions graphene_django/filter/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,18 @@ def merge_querysets(cls, default_queryset, queryset):

@classmethod
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
cls,
resolver,
connection,
node,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
Expand All @@ -98,6 +99,7 @@ def connection_resolver(
return super(DjangoFilterConnectionField, cls).connection_resolver(
resolver,
connection,
node,
qs,
max_limit,
enforce_first_or_last,
Expand All @@ -111,6 +113,7 @@ def get_resolver(self, parent_resolver):
self.connection_resolver,
parent_resolver,
self.type,
self.node_type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last,
Expand Down
10 changes: 9 additions & 1 deletion graphene_django/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def __init_subclass_with_meta__(
if not skip_registry:
registry.register(cls)

@classmethod
def refine_queryset(cls, qs, info, **kwargs):
return qs

@classmethod
def get_connection_parameters(cls):
return {}

def resolve_id(self, info):
return self.pk

Expand All @@ -130,6 +138,6 @@ def is_type_of(cls, root, info):
@classmethod
def get_node(cls, info, id):
try:
return cls._meta.model.objects.get(pk=id)
return cls.refine_queryset(cls._meta.model.objects, info).get(pk=id)
except cls._meta.model.DoesNotExist:
return None