/
connection.py
143 lines (117 loc) · 4.68 KB
/
connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import re
from collections import OrderedDict
from functools import partial

try:
    # Python 3.3+: the container ABCs live in collections.abc; importing them
    # directly from `collections` stopped working in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python 2
    from collections import Iterable

import six
from graphql_relay import connection_from_list

from ..types import Boolean, Int, List, String
from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options
from ..types.utils import get_fields_in_type, yank_fields_from_attrs
from ..utils.is_base_type import is_base_type
from ..utils.props import props
from .node import Node, is_node
# Relay page-info object type. Its four fields are declared in schema order:
# the two required "more items?" flags, then the two optional edge cursors.
class PageInfo(ObjectType):
    # Required flag: whether more items follow when paginating forwards.
    has_next_page = Boolean(
        description='When paginating forwards, are there more items?',
        name='hasNextPage',
        required=True,
    )
    # Required flag: whether more items precede when paginating backwards.
    has_previous_page = Boolean(
        description='When paginating backwards, are there more items?',
        name='hasPreviousPage',
        required=True,
    )
    # Optional cursor of the first edge on the page.
    start_cursor = String(
        description='When paginating backwards, the cursor to continue.',
        name='startCursor',
    )
    # Optional cursor of the last edge on the page.
    end_cursor = String(
        description='When paginating forwards, the cursor to continue.',
        name='endCursor',
    )
class ConnectionMeta(ObjectTypeMeta):
    """Metaclass that assembles the schema options of ``Connection`` subclasses.

    For each subclass of ``Connection`` it:

    * reads the inner ``Meta`` (``name``, ``description``, ``node``) into an
      ``Options`` object and requires ``node`` to be a subclass of ``Node``
      or ``ObjectType``;
    * defaults the type name to ``<Base>Connection``, where ``<Base>`` is the
      class name with any trailing ``Connection`` removed;
    * generates an ``Edge`` object type named ``<Base>Edge`` holding ``node``
      and ``cursor`` fields plus any fields declared on an optional inner
      ``Edge`` class, and stores it as the class attribute ``Edge``;
    * declares the ``page_info`` (``pageInfo``) and ``edges`` fields, followed
      by any extra fields declared on the connection class itself.
    """

    def __new__(cls, name, bases, attrs):
        # Only configure subclasses of Connection; the Connection base class
        # itself is created as a plain class with no options.
        if not is_base_type(bases, ConnectionMeta):
            return type.__new__(cls, name, bases, attrs)

        options = Options(
            attrs.pop('Meta', None),
            name=None,
            description=None,
            node=None,
        )
        options.interfaces = ()

        # Name the class being defined: inside a metaclass ``__new__``, ``cls``
        # is the metaclass, so ``cls.__name__`` would report "ConnectionMeta".
        assert options.node, 'You have to provide a node in {}.Meta'.format(name)
        assert issubclass(options.node, (Node, ObjectType)), (
            'Received incompatible node "{}" for Connection {}.'
        ).format(options.node, name)

        base_name = re.sub('Connection$', '', name)
        if not options.name:
            options.name = '{}Connection'.format(base_name)

        # Default edge fields; fields collected from an optional inner ``Edge``
        # class are merged on top (same-named entries replace these defaults).
        edge_class = attrs.pop('Edge', None)
        edge_fields = OrderedDict([
            ('node', Field(options.node, description='The item at the end of the edge')),
            ('cursor', Field(String, required=True, description='A cursor for use in pagination')),
        ])
        edge_attrs = props(edge_class) if edge_class else OrderedDict()
        extended_edge_fields = get_fields_in_type(ObjectType, edge_attrs)
        edge_fields.update(extended_edge_fields)
        edge_meta = type('Meta', (object, ), {
            'fields': edge_fields,
            'name': '{}Edge'.format(base_name),
        })
        # Presumably strips the collected fields out of edge_attrs so they are
        # not also left as plain class attributes (TODO confirm helper).
        yank_fields_from_attrs(edge_attrs, extended_edge_fields)
        edge = type('Edge', (ObjectType,), dict(edge_attrs, Meta=edge_meta))

        # Connection fields: the Relay pair first, then user-declared extras.
        options.local_fields = OrderedDict([
            ('page_info', Field(PageInfo, name='pageInfo', required=True)),
            ('edges', Field(List(edge))),
        ])
        typed_fields = get_fields_in_type(ObjectType, attrs)
        options.local_fields.update(typed_fields)
        options.fields = options.local_fields
        yank_fields_from_attrs(attrs, typed_fields)

        return type.__new__(cls, name, bases, dict(attrs, _meta=options, Edge=edge))
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
    """Base class for Relay connection object types.

    Subclasses must declare ``Meta.node``; ``ConnectionMeta`` then supplies
    the ``pageInfo`` and ``edges`` fields and a generated ``Edge`` type.
    """
class IterableConnectionField(Field):
    """Field whose value is a Relay connection built from an iterable.

    The field carries the Relay pagination arguments ``before``, ``after``
    (strings) and ``first``, ``last`` (ints). Its resolver wraps the parent
    resolver and turns the returned iterable into an instance of the
    connection type via ``graphql_relay.connection_from_list``.
    """

    def __init__(self, type, *args, **kwargs):
        # Add the pagination arguments unless the caller supplied its own
        # definition for one of them. Passing them as explicit keywords next
        # to ``**kwargs`` turned such an override into a TypeError.
        kwargs.setdefault('before', String())
        kwargs.setdefault('after', String())
        kwargs.setdefault('first', Int())
        kwargs.setdefault('last', Int())
        super(IterableConnectionField, self).__init__(type, *args, **kwargs)

    @property
    def type(self):
        """Return the ``Connection`` subclass this field resolves to.

        If the declared type satisfies ``is_node``, its ``Connection``
        attribute is used; otherwise the declared type itself must be a
        ``Connection`` subclass.

        :raises AssertionError: if the result is not a ``Connection`` subclass.
        """
        field_type = super(IterableConnectionField, self).type
        connection_type = field_type.Connection if is_node(field_type) else field_type
        assert issubclass(connection_type, Connection), (
            '{} type have to be a subclass of Connection. Received "{}".'
        ).format(str(self), connection_type)
        return connection_type

    @staticmethod
    def connection_resolver(resolver, connection, root, args, context, info):
        """Resolve the field into an instance of ``connection``.

        :param resolver: the wrapped parent resolver, called with
            ``(root, args, context, info)``.
        :param connection: the ``Connection`` subclass to produce.
        :param args: the field arguments, also used as pagination arguments.
        :returns: the resolved value unchanged if it already is a
            ``connection`` instance; otherwise the page built by
            ``connection_from_list`` using ``connection.Edge`` and ``PageInfo``.
        :raises AssertionError: if the resolved value is not iterable.
        """
        resolved = resolver(root, args, context, info)
        # A resolver may build the connection itself; hand it back untouched.
        if isinstance(resolved, connection):
            return resolved

        assert isinstance(resolved, Iterable), (
            'Resolved value from the connection field have to be iterable. '
            'Received "{}"'
        ).format(resolved)

        # connection_from_list calls len() on its data and slices it, so
        # iterables lacking either (generators, sets, dict views, ...) are
        # materialized first. Objects supporting both keep their own
        # (possibly lazy) len/slicing behaviour.
        if not (hasattr(resolved, '__len__') and hasattr(resolved, '__getitem__')):
            resolved = list(resolved)

        return connection_from_list(
            resolved,
            args,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )

    def get_resolver(self, parent_resolver):
        """Wrap ``parent_resolver`` so its result is converted to ``self.type``."""
        return partial(self.connection_resolver, parent_resolver, self.type)
# Shorter alias: ConnectionField and IterableConnectionField are the same class.
ConnectionField = IterableConnectionField