diff --git a/abstract_test.go b/abstract_test.go index 623b6cba..9d14c07a 100644 --- a/abstract_test.go +++ b/abstract_test.go @@ -158,7 +158,6 @@ func TestIsTypeOfUsedToResolveRuntimeTypeForInterface(t *testing.T) { } } - func TestAppendTypeUsedToAddRuntimeCustomScalarTypeForInterface(t *testing.T) { petType := graphql.NewInterface(graphql.InterfaceConfig{ @@ -247,7 +246,6 @@ func TestAppendTypeUsedToAddRuntimeCustomScalarTypeForInterface(t *testing.T) { }, }, }), - }) if err != nil { t.Fatalf("Error in schema %v", err.Error()) @@ -297,8 +295,6 @@ func TestAppendTypeUsedToAddRuntimeCustomScalarTypeForInterface(t *testing.T) { } } - - func TestIsTypeOfUsedToResolveRuntimeTypeForUnion(t *testing.T) { dogType := graphql.NewObject(graphql.ObjectConfig{ diff --git a/examples/custom-scalar-type/main.go b/examples/custom-scalar-type/main.go new file mode 100644 index 00000000..e7203a06 --- /dev/null +++ b/examples/custom-scalar-type/main.go @@ -0,0 +1,141 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/language/ast" +) + +type CustomID struct { + value string +} + +func (id *CustomID) String() string { + return id.value +} + +func NewCustomID(v string) *CustomID { + return &CustomID{value: v} +} + +var CustomScalarType = graphql.NewScalar(graphql.ScalarConfig{ + Name: "CustomScalarType", + Description: "The `CustomScalarType` scalar type represents an ID Object.", + // Serialize serializes `CustomID` to string. + Serialize: func(value interface{}) interface{} { + switch value := value.(type) { + case CustomID: + return value.String() + case *CustomID: + v := *value + return v.String() + default: + return nil + } + }, + // ParseValue parses GraphQL variables from `string` to `CustomID`. + ParseValue: func(value interface{}) interface{} { + switch value := value.(type) { + case string: + return NewCustomID(value) + case *string: + return NewCustomID(*value) + default: + return nil + } + }, + // ParseLiteral parses GraphQL AST value to `CustomID`. + ParseLiteral: func(valueAST ast.Value) interface{} { + switch valueAST := valueAST.(type) { + case *ast.StringValue: + return NewCustomID(valueAST.Value) + default: + return nil + } + }, +}) + +type Customer struct { + ID *CustomID `json:"id"` +} + +var CustomerType = graphql.NewObject(graphql.ObjectConfig{ + Name: "Customer", + Fields: graphql.Fields{ + "id": &graphql.Field{ + Type: CustomScalarType, + }, + }, +}) + +func main() { + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "customers": &graphql.Field{ + Type: graphql.NewList(CustomerType), + Args: graphql.FieldConfigArgument{ + "id": &graphql.ArgumentConfig{ + Type: CustomScalarType, + }, + }, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + // id := p.Args["id"] + // log.Printf("id from arguments: %+v", id) + customers := []Customer{ + Customer{ID: NewCustomID("fb278f2a4a13f")}, + } + return customers, nil + }, + }, + }, + }), + }) + if err != nil { + log.Fatal(err) + } + query := ` + query { + customers { + id + } + } + ` + /* + queryWithVariable := ` + query($id: CustomScalarType) { + customers(id: $id) { + id + } + } + ` + */ + /* + queryWithArgument := ` + query { + customers(id: "5b42ba57289") { + id + } + } + ` + */ + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: query, + VariableValues: map[string]interface{}{ + "id": "5b42ba57289", + }, + }) + if len(result.Errors) > 0 { + log.Fatal(result) + } + b, err := json.Marshal(result) + if err != nil { + log.Fatal(err) + } + fmt.Println(string(b)) +} diff --git a/examples/modify-context/main.go b/examples/modify-context/main.go new file mode 100644 index 00000000..0432b3ba --- /dev/null +++ b/examples/modify-context/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + + "github.com/graphql-go/graphql" +) + +type User struct { + ID int `json:"id"` +} + +var UserType = graphql.NewObject(graphql.ObjectConfig{ + Name: "User", + Fields: graphql.Fields{ + "id": &graphql.Field{ + Type: graphql.Int, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + rootValue := p.Info.RootValue.(map[string]interface{}) + if rootValue["data-from-parent"] == "ok" && + rootValue["data-before-execution"] == "ok" { + user := p.Source.(User) + return user.ID, nil + } + return nil, nil + }, + }, + }, +}) + +func main() { + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "users": &graphql.Field{ + Type: graphql.NewList(UserType), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + rootValue := p.Info.RootValue.(map[string]interface{}) + rootValue["data-from-parent"] = "ok" + result := []User{ + User{ID: 1}, + } + return result, nil + + }, + }, + }, + }), + }) + if err != nil { + log.Fatal(err) + } + ctx := context.WithValue(context.Background(), "currentUser", User{ID: 100}) + // Instead of trying to modify context within a resolve function, use: + // `graphql.Params.RootObject` is a mutable optional variable and available on + // each resolve function via: `graphql.ResolveParams.Info.RootValue`. + rootObject := map[string]interface{}{ + "data-before-execution": "ok", + } + result := graphql.Do(graphql.Params{ + Context: ctx, + RequestString: "{ users { id } }", + RootObject: rootObject, + Schema: schema, + }) + b, err := json.Marshal(result) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s\n", string(b)) // {"data":{"users":[{"id":1}]}} +} diff --git a/executor.go b/executor.go index b6009a63..408e173e 100644 --- a/executor.go +++ b/executor.go @@ -567,12 +567,14 @@ func completeValueCatchingError(eCtx *executionContext, returnType Type, fieldAS func completeValue(eCtx *executionContext, returnType Type, fieldASTs []*ast.Field, info ResolveInfo, result interface{}) interface{} { resultVal := reflect.ValueOf(result) - if resultVal.IsValid() && resultVal.Type().Kind() == reflect.Func { + for resultVal.IsValid() && resultVal.Type().Kind() == reflect.Func { if propertyFn, ok := result.(func() interface{}); ok { - return propertyFn() + result = propertyFn() + resultVal = reflect.ValueOf(result) + } else { + err := gqlerrors.NewFormattedError("Error resolving func. Expected `func() interface{}` signature") + panic(gqlerrors.FormatError(err)) } - err := gqlerrors.NewFormattedError("Error resolving func. Expected `func() interface{}` signature") - panic(gqlerrors.FormatError(err)) } // If field type is NonNull, complete for inner type, and throw field error @@ -737,7 +739,7 @@ func completeListValue(eCtx *executionContext, returnType *List, fieldASTs []*as parentTypeName = info.ParentType.Name() } err := invariantf( - resultVal.IsValid() && resultVal.Type().Kind() == reflect.Slice, + resultVal.IsValid() && isIterable(result), "User Error: expected iterable, but did not find one "+ "for field %v.%v.", parentTypeName, info.FieldName) diff --git a/executor_test.go b/executor_test.go index 954d6d30..98741fdc 100644 --- a/executor_test.go +++ b/executor_test.go @@ -1807,3 +1807,90 @@ func TestContextDeadline(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedErrors, result.Errors)) } } + +func TestThunkResultsProcessedCorrectly(t *testing.T) { + barType := graphql.NewObject(graphql.ObjectConfig{ + Name: "Bar", + Fields: graphql.Fields{ + "bazA": &graphql.Field{ + Type: graphql.String, + }, + "bazB": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + + fooType := graphql.NewObject(graphql.ObjectConfig{ + Name: "Foo", + Fields: graphql.Fields{ + "bar": &graphql.Field{ + Type: barType, + Resolve: func(params graphql.ResolveParams) (interface{}, error) { + var bar struct { + BazA string + BazB string + } + bar.BazA = "A" + bar.BazB = "B" + + thunk := func() interface{} { return &bar } + return thunk, nil + }, + }, + }, + }) + + queryType := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "foo": &graphql.Field{ + Type: fooType, + Resolve: func(params graphql.ResolveParams) (interface{}, error) { + var foo struct{} + return foo, nil + }, + }, + }, + }) + + expectNoError := func(err error) { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + } + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: queryType, + }) + expectNoError(err) + + query := "{ foo { bar { bazA bazB } } }" + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: query, + }) + if len(result.Errors) != 0 { + t.Fatalf("expected no errors, got %v", result.Errors) + } + + foo := result.Data.(map[string]interface{})["foo"].(map[string]interface{}) + bar, ok := foo["bar"].(map[string]interface{}) + + if !ok { + t.Errorf("expected bar to be a map[string]interface{}: actual = %v", reflect.TypeOf(foo["bar"])) + } else { + if got, want := bar["bazA"], "A"; got != want { + t.Errorf("foo.bar.bazA: got=%v, want=%v", got, want) + } + if got, want := bar["bazB"], "B"; got != want { + t.Errorf("foo.bar.bazB: got=%v, want=%v", got, want) + } + } + + if t.Failed() { + b, err := json.Marshal(result.Data) + expectNoError(err) + t.Log(string(b)) + } +} diff --git a/lists_test.go b/lists_test.go index 7fd360c8..fa14121b 100644 --- a/lists_test.go +++ b/lists_test.go @@ -11,6 +11,8 @@ import ( ) func checkList(t *testing.T, testType graphql.Type, testData interface{}, expected *graphql.Result) { + // TODO: uncomment t.Helper when support for go1.8 is dropped. + //t.Helper() data := map[string]interface{}{ "test": testData, } @@ -561,8 +563,17 @@ func TestLists_NullableListOfNonNullArrayOfFunc_ContainsNulls(t *testing.T) { expected := &graphql.Result{ Data: map[string]interface{}{ "nest": map[string]interface{}{ - "test": []interface{}{ - 1, nil, 2, + "test": nil, + }, + }, + Errors: []gqlerrors.FormattedError{ + { + Message: "Cannot return null for non-nullable field DataType.test.", + Locations: []location.SourceLocation{ + { + Line: 1, + Column: 10, + }, }, }, }, @@ -752,9 +763,16 @@ func TestLists_NonNullListOfNonNullArrayOfFunc_ContainsNulls(t *testing.T) { } expected := &graphql.Result{ Data: map[string]interface{}{ - "nest": map[string]interface{}{ - "test": []interface{}{ - 1, nil, 2, + "nest": nil, + }, + Errors: []gqlerrors.FormattedError{ + { + Message: "Cannot return null for non-nullable field DataType.test.", + Locations: []location.SourceLocation{ + { + Line: 1, + Column: 10, + }, }, }, }, @@ -780,3 +798,20 @@ func TestLists_UserErrorExpectIterableButDidNotGetOne(t *testing.T) { } checkList(t, ttype, data, expected) } + +func TestLists_ArrayOfNullableObjects_ContainsValues(t *testing.T) { + ttype := graphql.NewList(graphql.Int) + data := [2]interface{}{ + 1, 2, + } + expected := &graphql.Result{ + Data: map[string]interface{}{ + "nest": map[string]interface{}{ + "test": []interface{}{ + 1, 2, + }, + }, + }, + } + checkList(t, ttype, data, expected) +} diff --git a/scalars.go b/scalars.go index 0465d1b2..1dc62fef 100644 --- a/scalars.go +++ b/scalars.go @@ -21,6 +21,8 @@ func coerceInt(value interface{}) interface{} { return 1 } return 0 + case *bool: + return coerceInt(*value) case int: if value < int(math.MinInt32) || value > int(math.MaxInt32) { return nil @@ -134,8 +136,44 @@ func coerceFloat(value interface{}) interface{} { return coerceFloat(*value) case int: return float64(value) + case *int: + return coerceFloat(*value) + case int8: + return float64(value) + case *int8: + return coerceFloat(*value) + case int16: + return float64(value) + case *int16: + return coerceFloat(*value) + case int32: + return float64(value) case *int32: return coerceFloat(*value) + case int64: + return float64(value) + case *int64: + return coerceFloat(*value) + case uint: + return float64(value) + case *uint: + return coerceFloat(*value) + case uint8: + return float64(value) + case *uint8: + return coerceFloat(*value) + case uint16: + return float64(value) + case *uint16: + return coerceFloat(*value) + case uint32: + return float64(value) + case *uint32: + return coerceFloat(*value) + case uint64: + return float64(value) + case *uint64: + return coerceFloat(*value) case float32: return value case *float32: @@ -153,7 +191,10 @@ func coerceFloat(value interface{}) interface{} { case *string: return coerceFloat(*value) } - return 0.0 + + // If the value cannot be transformed into an float, return nil instead of '0.0' + // to denote 'no float found' + return nil } // Float is the GraphQL float type definition. @@ -167,7 +208,7 @@ var Float = NewScalar(ScalarConfig{ ParseLiteral: func(valueAST ast.Value) interface{} { switch valueAST := valueAST.(type) { case *ast.FloatValue: - if floatValue, err := strconv.ParseFloat(valueAST.Value, 32); err == nil { + if floatValue, err := strconv.ParseFloat(valueAST.Value, 64); err == nil { return floatValue } case *ast.IntValue: @@ -238,6 +279,69 @@ func coerceBool(value interface{}) interface{} { return false case *int: return coerceBool(*value) + case int8: + if value != 0 { + return true + } + return false + case *int8: + return coerceBool(*value) + case int16: + if value != 0 { + return true + } + return false + case *int16: + return coerceBool(*value) + case int32: + if value != 0 { + return true + } + return false + case *int32: + return coerceBool(*value) + case int64: + if value != 0 { + return true + } + return false + case *int64: + return coerceBool(*value) + case uint: + if value != 0 { + return true + } + return false + case *uint: + return coerceBool(*value) + case uint8: + if value != 0 { + return true + } + return false + case *uint8: + return coerceBool(*value) + case uint16: + if value != 0 { + return true + } + return false + case *uint16: + return coerceBool(*value) + case uint32: + if value != 0 { + return true + } + return false + case *uint32: + return coerceBool(*value) + case uint64: + if value != 0 { + return true + } + return false + case *uint64: + return coerceBool(*value) } return false } diff --git a/scalars_test.go b/scalars_test.go new file mode 100644 index 00000000..fb136696 --- /dev/null +++ b/scalars_test.go @@ -0,0 +1,642 @@ +package graphql + +import ( + "math" + "testing" +) + +func TestCoerceInt(t *testing.T) { + tests := []struct { + in interface{} + want interface{} + }{ + { + in: false, + want: 0, + }, + { + in: true, + want: 1, + }, + { + in: boolPtr(false), + want: 0, + }, + { + in: boolPtr(true), + want: 1, + }, + { + in: int(math.MinInt32) - 1, + want: nil, + }, + { + in: int(math.MaxInt32) + 1, + want: nil, + }, + { + in: uint(math.MaxInt32) + 1, + want: nil, + }, + { + in: uint32(math.MaxInt32) + 1, + want: nil, + }, + { + in: int64(math.MinInt32) - 1, + want: nil, + }, + { + in: int64(math.MaxInt32) + 1, + want: nil, + }, + { + in: uint64(math.MaxInt32) + 1, + want: nil, + }, + { + // need to subtract more than one because of float32 precision + in: float32(math.MinInt32) - 1000, + want: nil, + }, + { + // need to add more than one because of float32 precision + in: float32(math.MaxInt32) + 1000, + want: nil, + }, + { + in: float64(math.MinInt32) - 1, + want: nil, + }, + { + in: float64(math.MaxInt32) + 1, + want: nil, + }, + { + in: int(math.MinInt32), + want: int(math.MinInt32), + }, + { + in: int(math.MaxInt32), + want: int(math.MaxInt32), + }, + { + in: intPtr(12), + want: 12, + }, + { + in: int8(13), + want: int(13), + }, + { + in: int8Ptr(14), + want: int(14), + }, + { + in: int16(15), + want: int(15), + }, + { + in: int16Ptr(16), + want: int(16), + }, + { + in: int32(17), + want: int(17), + }, + { + in: int32Ptr(18), + want: int(18), + }, + { + in: int64(19), + want: int(19), + }, + { + in: int64Ptr(20), + want: int(20), + }, + { + in: uint8(21), + want: int(21), + }, + { + in: uint8Ptr(22), + want: int(22), + }, + { + in: uint16(23), + want: int(23), + }, + { + in: uint16Ptr(24), + want: int(24), + }, + { + in: uint32(25), + want: int(25), + }, + { + in: uint32Ptr(26), + want: int(26), + }, + { + in: uint64(27), + want: int(27), + }, + { + in: uint64Ptr(28), + want: int(28), + }, + { + in: uintPtr(29), + want: int(29), + }, + { + in: float32(30.1), + want: int(30), + }, + { + in: float32Ptr(31.2), + want: int(31), + }, + { + in: float64(32), + want: int(32), + }, + { + in: float64Ptr(33.1), + want: int(33), + }, + { + in: "34", + want: int(34), + }, + { + in: stringPtr("35"), + want: int(35), + }, + { + in: "I'm not a number", + want: nil, + }, + { + in: make(map[string]interface{}), + want: nil, + }, + } + + for i, tt := range tests { + if got, want := coerceInt(tt.in), tt.want; got != want { + t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want) + } + } +} + +func TestCoerceFloat(t *testing.T) { + tests := []struct { + in interface{} + want interface{} + }{ + { + in: false, + want: 0.0, + }, + { + in: true, + want: 1.0, + }, + { + in: boolPtr(false), + want: 0.0, + }, + { + in: boolPtr(true), + want: 1.0, + }, + { + in: int(math.MinInt32), + want: float64(math.MinInt32), + }, + { + in: int(math.MaxInt32), + want: float64(math.MaxInt32), + }, + { + in: intPtr(12), + want: float64(12), + }, + { + in: int8(13), + want: float64(13), + }, + { + in: int8Ptr(14), + want: float64(14), + }, + { + in: int16(15), + want: float64(15), + }, + { + in: int16Ptr(16), + want: float64(16), + }, + { + in: int32(17), + want: float64(17), + }, + { + in: int32Ptr(18), + want: float64(18), + }, + { + in: int64(19), + want: float64(19), + }, + { + in: int64Ptr(20), + want: float64(20), + }, + { + in: uint8(21), + want: float64(21), + }, + { + in: uint8Ptr(22), + want: float64(22), + }, + { + in: uint16(23), + want: float64(23), + }, + { + in: uint16Ptr(24), + want: float64(24), + }, + { + in: uint32(25), + want: float64(25), + }, + { + in: uint32Ptr(26), + want: float64(26), + }, + { + in: uint64(27), + want: float64(27), + }, + { + in: uint64Ptr(28), + want: float64(28), + }, + { + in: uintPtr(29), + want: float64(29), + }, + { + in: float32(30), + want: float32(30), + }, + { + in: float32Ptr(31), + want: float32(31), + }, + { + in: float64(32), + want: float64(32), + }, + { + in: float64Ptr(33.2), + want: float64(33.2), + }, + { + in: "34", + want: float64(34), + }, + { + in: stringPtr("35.2"), + want: float64(35.2), + }, + { + in: "I'm not a number", + want: nil, + }, + { + in: make(map[string]interface{}), + want: nil, + }, + } + + for i, tt := range tests { + if got, want := coerceFloat(tt.in), tt.want; got != want { + t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want) + } + } +} + +func TestCoerceBool(t *testing.T) { + tests := []struct { + in interface{} + want interface{} + }{ + { + in: false, + want: false, + }, + { + in: true, + want: true, + }, + { + in: boolPtr(false), + want: false, + }, + { + in: boolPtr(true), + want: true, + }, + { + in: int(math.MinInt32), + want: true, + }, + { + in: int(math.MaxInt32), + want: true, + }, + { + in: int(0), + want: false, + }, + { + in: intPtr(12), + want: true, + }, + { + in: intPtr(0), + want: false, + }, + { + in: int8(13), + want: true, + }, + { + in: int8(0), + want: false, + }, + { + in: int8Ptr(14), + want: true, + }, + { + in: int8Ptr(0), + want: false, + }, + { + in: int16(15), + want: true, + }, + { + in: int16(0), + want: false, + }, + { + in: int16Ptr(16), + want: true, + }, + { + in: int16Ptr(0), + want: false, + }, + { + in: int32(17), + want: true, + }, + { + in: int32(0), + want: false, + }, + { + in: int32Ptr(18), + want: true, + }, + { + in: int32Ptr(0), + want: false, + }, + { + in: int64(19), + want: true, + }, + { + in: int64(0), + want: false, + }, + { + in: int64Ptr(20), + want: true, + }, + { + in: int64Ptr(0), + want: false, + }, + { + in: uint8(21), + want: true, + }, + { + in: uint8(0), + want: false, + }, + { + in: uint8Ptr(22), + want: true, + }, + { + in: uint8Ptr(0), + want: false, + }, + { + in: uint16(23), + want: true, + }, + { + in: uint16(0), + want: false, + }, + { + in: uint16Ptr(24), + want: true, + }, + { + in: uint16Ptr(0), + want: false, + }, + { + in: uint32(25), + want: true, + }, + { + in: uint32(0), + want: false, + }, + { + in: uint32Ptr(26), + want: true, + }, + { + in: uint32Ptr(0), + want: false, + }, + { + in: uint64(27), + want: true, + }, + { + in: uint64(0), + want: false, + }, + { + in: uint64Ptr(28), + want: true, + }, + { + in: uint64Ptr(0), + want: false, + }, + { + in: uintPtr(29), + want: true, + }, + { + in: uintPtr(0), + want: false, + }, + { + in: float32(30), + want: true, + }, + { + in: float32(0), + want: false, + }, + { + in: float32Ptr(31), + want: true, + }, + { + in: float32Ptr(0), + want: false, + }, + { + in: float64(32), + want: true, + }, + { + in: float64(0), + want: false, + }, + { + in: float64Ptr(33.2), + want: true, + }, + { + in: float64Ptr(0), + want: false, + }, + { + in: "34", + want: true, + }, + { + in: "false", + want: false, + }, + { + in: stringPtr("true"), + want: true, + }, + { + in: stringPtr("false"), + want: false, + }, + { + in: "I'm some random string", + want: true, + }, + { + in: "", + want: false, + }, + { + in: int8(0), + want: false, + }, + { + in: make(map[string]interface{}), + want: false, + }, + } + + for i, tt := range tests { + if got, want := coerceBool(tt.in), tt.want; got != want { + t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want) + } + } +} + +func boolPtr(b bool) *bool { + return &b +} + +func intPtr(n int) *int { + return &n +} + +func int8Ptr(n int8) *int8 { + return &n +} + +func int16Ptr(n int16) *int16 { + return &n +} + +func int32Ptr(n int32) *int32 { + return &n +} + +func int64Ptr(n int64) *int64 { + return &n +} + +func uintPtr(n uint) *uint { + return &n +} + +func uint8Ptr(n uint8) *uint8 { + return &n +} + +func uint16Ptr(n uint16) *uint16 { + return &n +} + +func uint32Ptr(n uint32) *uint32 { + return &n +} + +func uint64Ptr(n uint64) *uint64 { + return &n +} + +func float32Ptr(n float32) *float32 { + return &n +} + +func float64Ptr(n float64) *float64 { + return &n +} + +func stringPtr(s string) *string { + return &s +} diff --git a/values.go b/values.go index bff25e9b..a68866e1 100644 --- a/values.go +++ b/values.go @@ -318,6 +318,15 @@ func isNullish(src interface{}) bool { return false } +// Returns true if src is a slice or an array +func isIterable(src interface{}) bool { + if src == nil { + return false + } + t := reflect.TypeOf(src) + return t.Kind() == reflect.Slice || t.Kind() == reflect.Array +} + /** * Produces a value given a GraphQL Value AST. * diff --git a/values_test.go b/values_test.go new file mode 100644 index 00000000..015599a5 --- /dev/null +++ b/values_test.go @@ -0,0 +1,18 @@ +package graphql + +import "testing" + +func TestIsIterable(t *testing.T) { + if !isIterable([]int{}) { + t.Fatal("expected isIterable to return true for a slice, got false") + } + if !isIterable([]int{}) { + t.Fatal("expected isIterable to return true for an array, got false") + } + if isIterable(1) { + t.Fatal("expected isIterable to return false for an int, got true") + } + if isIterable(nil) { + t.Fatal("expected isIterable to return false for nil, got true") + } +}