forked from samsarahq/thunder
/
batch.go
387 lines (343 loc) · 12.7 KB
/
batch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
package schemabuilder
import (
"context"
"fmt"
"reflect"
"github.com/guad/thunder/batch"
"github.com/guad/thunder/graphql"
)
// buildBatchFunctionWithFallback corresponds to buildFunction for a
// batchFieldFunc that also provides a non-batch fallback func.
//
// The batch func (m.Fn) and the fallback (m.BatchArgs.FallbackFunc) are built
// independently and then checked for interchangeability:
//   - same signature shape (context / args / selection set / error / return),
//     and the fallback must take a source, as the batch func always does;
//   - identical GraphQL return type;
//   - identical GraphQL arguments (same names and types).
//
// The returned field is the batch field with Resolve taken from the fallback
// and UseBatchFunc set to m.BatchArgs.ShouldUseBatchFunc, which must be
// non-nil. NOTE(review): presumably the executor consults UseBatchFunc to pick
// between BatchResolver and Resolve — confirm against the executor.
func (sb *schemaBuilder) buildBatchFunctionWithFallback(typ reflect.Type, m *method) (*graphql.Field, error) {
	fallbackField, fallbackFuncCtx, err := sb.buildFunctionAndFuncCtx(typ, &method{
		Fn:                m.BatchArgs.FallbackFunc,
		MarkedNonNullable: m.MarkedNonNullable,
		// We don't want to accidentally make the fallback non-expensive
		// if someone forgets to pass the Expensive option to `BatchFieldFuncWithFallback`.
		// It's safe to assume the fallback is expensive, because why else
		// would someone bother batching it?
		Expensive: true,
	})
	if err != nil {
		return nil, err
	}

	batchField, batchFuncCtx, err := sb.buildBatchFunctionAndFuncCtx(typ, m)
	if err != nil {
		return nil, err
	}

	if fallbackFuncCtx.hasContext != batchFuncCtx.hasContext ||
		!fallbackFuncCtx.hasSource || // Batch func always has a source.
		fallbackFuncCtx.hasArgs != batchFuncCtx.hasArgs ||
		fallbackFuncCtx.hasSelectionSet != batchFuncCtx.hasSelectionSet ||
		fallbackFuncCtx.hasError != batchFuncCtx.hasError ||
		fallbackFuncCtx.hasRet != batchFuncCtx.hasRet {
		return nil, fmt.Errorf("batch and fallback function signatures did not match")
	}

	if fallbackField.Type.String() != batchField.Type.String() {
		return nil, fmt.Errorf("batch and fallback graphql return types did not match: Batch(%v) Fallback(%v)", batchField.Type, fallbackField.Type)
	}

	if len(fallbackField.Args) != len(batchField.Args) {
		return nil, fmt.Errorf("batch and fallback arg type did not match: Batch(%v) Fallback(%v)", batchField.Args, fallbackField.Args)
	}
	for name, fallbackTyp := range fallbackField.Args {
		batchTyp, ok := batchField.Args[name]
		if !ok {
			// Report the missing arg by name instead of printing a nil batch type.
			return nil, fmt.Errorf("batch and fallback func arg types did not match: batch func has no arg %q (Fallback(%v))", name, fallbackTyp)
		}
		if fallbackTyp.String() != batchTyp.String() {
			return nil, fmt.Errorf("batch and fallback func arg types did not match: Batch(%v) Fallback(%v)", batchTyp, fallbackTyp)
		}
	}

	if m.BatchArgs.ShouldUseBatchFunc == nil {
		return nil, fmt.Errorf("batch function requires fallback check function (got nil)")
	}
	batchField.UseBatchFunc = m.BatchArgs.ShouldUseBatchFunc
	batchField.Resolve = fallbackField.Resolve
	return batchField, nil
}
// buildBatchFunction builds a batch-only field from m.Fn. There is no
// per-source fallback resolver, so the field's UseBatchFunc always reports
// true.
func (sb *schemaBuilder) buildBatchFunction(typ reflect.Type, m *method) (*graphql.Field, error) {
	field, _, buildErr := sb.buildBatchFunctionAndFuncCtx(typ, m)
	if buildErr != nil {
		return nil, buildErr
	}
	alwaysBatch := func(_ context.Context) bool { return true }
	field.UseBatchFunc = alwaysBatch
	return field, nil
}
// buildBatchFunctionAndFuncCtx corresponds to buildFunction for a
// batchFieldFunc. It validates that m.Fn has the shape
//
//	func([ctx,] sources map[batch.Index][*]Parent[, args][, selectionSet]) ([map[batch.Index]<Type>][, error])
//
// where Parent is typ. It returns the resulting batch graphql.Field together
// with the parsed batchFuncContext; callers may compare that context against
// a fallback func's.
//
// typ must not be a pointer type.
func (sb *schemaBuilder) buildBatchFunctionAndFuncCtx(typ reflect.Type, m *method) (*graphql.Field, *batchFuncContext, error) {
	funcCtx := &batchFuncContext{parentTyp: typ}

	if typ.Kind() == reflect.Ptr {
		return nil, nil, fmt.Errorf("source-type of buildBatchFunction cannot be a pointer (got: %v)", typ)
	}

	callableFunc, err := funcCtx.getFuncVal(m)
	if err != nil {
		return nil, nil, err
	}

	// Consume the inputs in declaration order; each consumer pops what it
	// recognises and records it on funcCtx.
	in := funcCtx.getFuncInputTypes()
	if len(in) == 0 {
		return nil, nil, fmt.Errorf("batch Field funcs require at least one input field")
	}
	in = funcCtx.consumeContext(in)
	in, err = funcCtx.consumeRequiredSourceBatch(in)
	if err != nil {
		return nil, nil, err
	}
	argParser, args, in, err := funcCtx.consumeArgs(sb, in)
	if err != nil {
		return nil, nil, err
	}
	in = funcCtx.consumeSelectionSet(in)

	// We have succeeded if no arguments remain.
	if len(in) != 0 {
		// The source batch key must be batch.Index (see consumeRequiredSourceBatch).
		return nil, nil, fmt.Errorf("%s arguments should be [context,]map[batch.Index][*]%s[, args][, selectionSet]", funcCtx.funcType, typ)
	}

	out := funcCtx.getFuncOutputTypes()
	retType, out, err := funcCtx.consumeReturnValue(m, sb, out)
	if err != nil {
		return nil, nil, err
	}
	out = funcCtx.consumeReturnError(out)
	if len(out) > 0 {
		// The result batch key must be batch.Index (see consumeReturnValue).
		return nil, nil, fmt.Errorf("%s return should be [map[batch.Index]<Type>][,error]", funcCtx.funcType)
	}

	batchExecFunc := func(ctx context.Context, sources []interface{}, funcRawArgs interface{}, selectionSet *graphql.SelectionSet) ([]interface{}, error) {
		// Set up function arguments.
		funcInputArgs, idxValues := funcCtx.prepareResolveArgs(sources, funcRawArgs, ctx, selectionSet)
		// Call the function.
		funcOutputArgs := callableFunc.Call(funcInputArgs)
		return funcCtx.extractResultsAndErr(funcOutputArgs, idxValues, retType)
	}

	return &graphql.Field{
		BatchResolver:              batchExecFunc,
		Batch:                      true,
		External:                   true,
		Args:                       args,
		Type:                       retType,
		ParseArguments:             argParser.Parse,
		Expensive:                  m.Expensive,
		NumParallelInvocationsFunc: m.ConcurrencyArgs.numParallelInvocationsFunc,
	}, funcCtx, nil
}
// batchFuncContext records what buildBatchFunctionAndFuncCtx learned while
// parsing a batch field func's signature. prepareResolveArgs and
// extractResultsAndErr use it later to call the func and decode its results.
type batchFuncContext struct {
	hasContext        bool         // first input is a context parameter
	hasArgs           bool         // an args struct follows the source batch
	hasSelectionSet   bool         // a selection-set input is present
	hasRet            bool         // first output is a map[batch.Index]<Type> result batch
	hasError          bool         // an error output is present
	enforceNoNilResps bool         // missing/nil results are errors (field marked non-nullable)
	funcType          reflect.Type // reflect type of the batch func itself
	batchMapType      reflect.Type // exact map type of the source-batch input
	isPtrFunc         bool         // source-batch values are pointers to parentTyp
	parentTyp         reflect.Type // non-pointer source type the field is defined on
}
// getFuncVal reflects m.Fn and, when it is a function, records its type in
// funcCtx.funcType and returns it for use with reflect.Value.Call. If m.Fn is
// not a function, funcType is left untouched and an error is returned.
func (funcCtx *batchFuncContext) getFuncVal(m *method) (reflect.Value, error) {
	fnVal := reflect.ValueOf(m.Fn)
	if fnVal.Kind() == reflect.Func {
		funcCtx.funcType = fnVal.Type()
		return fnVal, nil
	}
	return fnVal, fmt.Errorf("fun must be func, not %s", fnVal)
}
// getFuncInputTypes lists the parameter types of funcCtx.funcType in
// declaration order. funcType must already be set (see getFuncVal).
func (funcCtx *batchFuncContext) getFuncInputTypes() []reflect.Type {
	params := make([]reflect.Type, funcCtx.funcType.NumIn())
	for i := range params {
		params[i] = funcCtx.funcType.In(i)
	}
	return params
}
// consumeContext pops a leading context parameter (contextType) off in, if
// one is present, and records that fact in funcCtx.hasContext. The remaining
// parameters are returned.
func (funcCtx *batchFuncContext) consumeContext(in []reflect.Type) []reflect.Type {
	if len(in) == 0 || in[0] != contextType {
		return in
	}
	funcCtx.hasContext = true
	return in[1:]
}
// consumeRequiredSourceBatch pops the mandatory source-batch parameter off
// in. It must be map[batch.Index]*Parent or map[batch.Index]Parent, where
// Parent is funcCtx.parentTyp; a missing or differently-typed parameter makes
// the func invalid and an error is returned. On success the map type is
// recorded in funcCtx.batchMapType and whether its values are pointers in
// funcCtx.isPtrFunc.
func (funcCtx *batchFuncContext) consumeRequiredSourceBatch(in []reflect.Type) ([]reflect.Type, error) {
	if len(in) == 0 {
		return nil, fmt.Errorf("requires batch source input parameter for func")
	}
	batchType, rest := in[0], in[1:]
	ptrToParent := reflect.PtrTo(funcCtx.parentTyp)

	valid := batchType.Kind() == reflect.Map && isBatchIndexType(batchType.Key())
	if valid {
		elem := batchType.Elem()
		valid = elem == ptrToParent || elem == funcCtx.parentTyp
	}
	if !valid {
		parentName := funcCtx.parentTyp.String()
		return nil, fmt.Errorf(
			"invalid source batch type, expected one of map[batch.Index]*%s or map[batch.Index]%s, but got %s",
			parentName,
			parentName,
			batchType.String(),
		)
	}

	funcCtx.isPtrFunc = batchType.Elem() == ptrToParent
	funcCtx.batchMapType = batchType
	return rest, nil
}
// consumeArgs pops the optional args parameter off in. Any parameter in this
// position other than a selection set is treated as the args struct. It is
// turned into an argParser by sb.makeStructParser, and the resulting input
// object's fields become the field's GraphQL arguments.
//
// Returns (parser, args, remaining params, nil) on success. Returns
// (nil, nil, in, nil) when there is no args parameter, leaving
// funcCtx.hasArgs false. On error, all non-error results are nil.
func (funcCtx *batchFuncContext) consumeArgs(sb *schemaBuilder, in []reflect.Type) (*argParser, map[string]graphql.Type, []reflect.Type, error) {
	if len(in) == 0 || in[0] == selectionSetType {
		return nil, nil, in, nil
	}
	argsType, rest := in[0], in[1:]

	parser, parsedType, err := sb.makeStructParser(argsType)
	if err != nil {
		// Unnamed types (pointers, anonymous structs, ...) have an empty Name();
		// fall back to the full type string so the message stays informative.
		argsName := argsType.Name()
		if argsName == "" {
			argsName = argsType.String()
		}
		return nil, nil, nil, fmt.Errorf("attempted to parse %s as arguments struct, but failed: %s", argsName, err.Error())
	}

	inputObject, ok := parsedType.(*graphql.InputObject)
	if !ok {
		return nil, nil, nil, fmt.Errorf("%s's args should be an object", funcCtx.funcType)
	}

	// Copy the input fields so the field's Args map is not shared with the
	// InputObject's own map.
	args := make(map[string]graphql.Type, len(inputObject.InputFields))
	for name, typ := range inputObject.InputFields {
		args[name] = typ
	}
	funcCtx.hasArgs = true
	return parser, args, rest, nil
}
// consumeSelectionSet pops a leading selection-set parameter
// (selectionSetType) off in, if present, and records it in
// funcCtx.hasSelectionSet. See graphql.SelectionSet for more information
// about selection sets.
func (funcCtx *batchFuncContext) consumeSelectionSet(in []reflect.Type) []reflect.Type {
	hasSelection := len(in) > 0 && in[0] == selectionSetType
	if !hasSelection {
		return in
	}
	funcCtx.hasSelectionSet = true
	return in[1:]
}
// getFuncOutputTypes lists the result types of funcCtx.funcType in
// declaration order. funcType must already be set (see getFuncVal).
func (funcCtx *batchFuncContext) getFuncOutputTypes() []reflect.Type {
	results := make([]reflect.Type, funcCtx.funcType.NumOut())
	for i := range results {
		results[i] = funcCtx.funcType.Out(i)
	}
	return results
}
// consumeReturnValue pops the optional result-batch return value off out and
// derives the field's GraphQL type from it.
//
//   - With no result batch (out is empty or starts with an error), the field
//     is typed as the GraphQL type of bool and hasRet stays false.
//   - Otherwise out[0] must be map[batch.Index]<Type> and the field type is
//     the GraphQL type of <Type>. A non-null wrapper is dropped unless it
//     wraps a list.
//   - If m.MarkedNonNullable, the type is (re)wrapped as non-null and
//     enforceNoNilResps is set.
//
// The remaining outputs are returned alongside the type.
func (funcCtx *batchFuncContext) consumeReturnValue(m *method, sb *schemaBuilder, out []reflect.Type) (graphql.Type, []reflect.Type, error) {
	if len(out) == 0 || out[0] == errType {
		boolType, err := sb.getType(reflect.TypeOf(true))
		if err != nil {
			return nil, nil, err
		}
		return boolType, out, nil
	}

	resultBatch, rest := out[0], out[1:]
	if resultBatch.Kind() != reflect.Map || !isBatchIndexType(resultBatch.Key()) {
		return nil, nil, fmt.Errorf(
			"invalid response batch type, expected map[batch.Index]<Type>, but got %s",
			resultBatch.String(),
		)
	}

	fieldType, err := sb.getType(resultBatch.Elem())
	if err != nil {
		return nil, nil, err
	}

	// Batch functions don't support NonNull responses by default unless they
	// are lists we can fill with zero length values.
	if nn, isNonNull := fieldType.(*graphql.NonNull); isNonNull {
		if _, wrapsList := nn.Type.(*graphql.List); !wrapsList {
			fieldType = nn.Type
		}
	}

	if m.MarkedNonNullable {
		funcCtx.enforceNoNilResps = true
		if _, alreadyNonNull := fieldType.(*graphql.NonNull); !alreadyNonNull {
			fieldType = &graphql.NonNull{Type: fieldType}
		}
	}

	funcCtx.hasRet = true
	return fieldType, rest, nil
}
// batchIndexTyp is the reflect.Type of batch.Index. It is derived via a typed
// nil pointer so it works whatever kind of type batch.Index is.
var batchIndexTyp = reflect.TypeOf((*batch.Index)(nil)).Elem()

// isBatchIndexType reports whether t is exactly batch.Index. It is used to
// validate the key type of source and result batch maps.
func isBatchIndexType(t reflect.Type) bool {
	return batchIndexTyp == t
}
// consumeReturnError pops an error output off out, if it is the next
// remaining result, and records it in funcCtx.hasError. The remaining
// outputs are returned.
func (funcCtx *batchFuncContext) consumeReturnError(out []reflect.Type) []reflect.Type {
	if len(out) == 0 || out[0] != errType {
		return out
	}
	funcCtx.hasError = true
	return out[1:]
}
// prepareResolveArgs builds the reflect.Value argument list for calling the
// batch func — [ctx,] sourceBatch[, args][, selectionSet] — according to the
// flags recorded in funcCtx.
//
// The i-th source is stored in the batch map under batch.NewIndex(i). It is
// converted to the map's value form when needed: a pointer source is
// dereferenced for a value-keyed func, and a value source is copied into a
// fresh pointer for a pointer func. The keys are also returned in source order
// as idxValues, so results can be read back out of the response map.
func (funcCtx *batchFuncContext) prepareResolveArgs(sources []interface{}, args interface{}, ctx context.Context, selectionSet *graphql.SelectionSet) (in []reflect.Value, idxValues []reflect.Value) {
	callArgs := make([]reflect.Value, 0, funcCtx.funcType.NumIn())
	if funcCtx.hasContext {
		callArgs = append(callArgs, reflect.ValueOf(ctx))
	}

	sourceBatch := reflect.MakeMapWithSize(funcCtx.batchMapType, len(sources))
	keys := make([]reflect.Value, len(sources))
	for i := range sources {
		key := reflect.ValueOf(batch.NewIndex(i))
		keys[i] = key

		val := reflect.ValueOf(sources[i])
		isPtr := val.Kind() == reflect.Ptr
		if isPtr && !funcCtx.isPtrFunc {
			// Func wants values: dereference the pointer source.
			val = val.Elem()
		} else if !isPtr && funcCtx.isPtrFunc {
			// Func wants pointers: copy the value into a new pointer.
			boxed := reflect.New(funcCtx.parentTyp)
			boxed.Elem().Set(val)
			val = boxed
		}
		sourceBatch.SetMapIndex(key, val)
	}
	callArgs = append(callArgs, sourceBatch)

	// Remaining optional arguments, in signature order.
	if funcCtx.hasArgs {
		callArgs = append(callArgs, reflect.ValueOf(args))
	}
	if funcCtx.hasSelectionSet {
		callArgs = append(callArgs, reflect.ValueOf(selectionSet))
	}
	return callArgs, keys
}
// extractResultsAndErr turns the batch func's return values into one result
// per source, in the same order as idxValues.
//
// If the func has an error output and it is non-nil, that error is returned.
// If the func has no result batch, every source resolves to true. Otherwise
// each source's result is looked up in the returned map by its idxValues key.
// Missing entries and nil pointers stay nil, or cause an error when
// funcCtx.enforceNoNilResps is set. retType is not used here.
func (funcCtx *batchFuncContext) extractResultsAndErr(out []reflect.Value, idxValues []reflect.Value, retType graphql.Type) ([]interface{}, error) {
	if funcCtx.hasError {
		if errVal := out[len(out)-1]; !errVal.IsNil() {
			return nil, errVal.Interface().(error)
		}
	}

	results := make([]interface{}, len(idxValues))
	if !funcCtx.hasRet {
		for i := range results {
			results[i] = true
		}
		return results, nil
	}

	respBatch := out[0]
	for i, key := range idxValues {
		val := respBatch.MapIndex(key)
		isNull := !val.IsValid() || (val.Kind() == reflect.Ptr && val.IsNil())
		if !isNull {
			results[i] = val.Interface()
			continue
		}
		if funcCtx.enforceNoNilResps {
			return nil, fmt.Errorf("%s is marked non-nullable but returned a null value", funcCtx.funcType)
		}
	}
	return results, nil
}