diff --git a/definition.go b/definition.go index c7aaa3be..bf42b530 100644 --- a/definition.go +++ b/definition.go @@ -1,7 +1,6 @@ package graphql import ( - "errors" "fmt" "reflect" "regexp" @@ -10,7 +9,7 @@ import ( "golang.org/x/net/context" ) -// These are all of the possible kinds of +// Type interface for all of the possible kinds of GraphQL types type Type interface { Name() string Description() string @@ -28,7 +27,7 @@ var _ Type = (*List)(nil) var _ Type = (*NonNull)(nil) var _ Type = (*Argument)(nil) -// These types may be used as input types for arguments and directives. +// Input interface for types that may be used as input types for arguments and directives. type Input interface { Name() string Description() string @@ -42,6 +41,7 @@ var _ Input = (*InputObject)(nil) var _ Input = (*List)(nil) var _ Input = (*NonNull)(nil) +// IsInputType determines if given type is a GraphQLInputType func IsInputType(ttype Type) bool { named := GetNamed(ttype) if _, ok := named.(*Scalar); ok { @@ -56,6 +56,7 @@ func IsInputType(ttype Type) bool { return false } +// IsOutputType determines if given type is a GraphQLOutputType func IsOutputType(ttype Type) bool { name := GetNamed(ttype) if _, ok := name.(*Scalar); ok { @@ -76,6 +77,7 @@ func IsOutputType(ttype Type) bool { return false } +// IsLeafType determines if given type is a leaf value func IsLeafType(ttype Type) bool { named := GetNamed(ttype) if _, ok := named.(*Scalar); ok { @@ -87,7 +89,7 @@ func IsLeafType(ttype Type) bool { return false } -// These types may be used as output types as the result of fields. +// Output interface for types that may be used as output types as the result of fields. type Output interface { Name() string Description() string @@ -103,7 +105,7 @@ var _ Output = (*Enum)(nil) var _ Output = (*List)(nil) var _ Output = (*NonNull)(nil) -// These types may describe the parent context of a selection set. +// Composite interface for types that may describe the parent context of a selection set. type Composite interface { Name() string } @@ -112,6 +114,7 @@ var _ Composite = (*Object)(nil) var _ Composite = (*Interface)(nil) var _ Composite = (*Union)(nil) +// IsCompositeType determines if given type is a GraphQLComposite type func IsCompositeType(ttype interface{}) bool { if _, ok := ttype.(*Object); ok { return true @@ -125,8 +128,9 @@ func IsCompositeType(ttype interface{}) bool { return false } -// These types may describe the parent context of a selection set. +// Abstract interface for types that may describe the parent context of a selection set. type Abstract interface { + Name() string ObjectType(value interface{}, info ResolveInfo) *Object PossibleTypes() []*Object IsPossibleType(ttype *Object) bool @@ -135,6 +139,7 @@ type Abstract interface { var _ Abstract = (*Interface)(nil) var _ Abstract = (*Union)(nil) +// Nullable interface for types that can accept null as a value. type Nullable interface { } @@ -146,6 +151,7 @@ var _ Nullable = (*Enum)(nil) var _ Nullable = (*InputObject)(nil) var _ Nullable = (*List)(nil) +// GetNullable returns the Nullable type of the given GraphQL type func GetNullable(ttype Type) Nullable { if ttype, ok := ttype.(*NonNull); ok { return ttype.OfType @@ -153,7 +159,7 @@ func GetNullable(ttype Type) Nullable { return ttype } -// These named types do not include modifiers like List or NonNull. +// Named interface for types that do not include modifiers like List or NonNull. type Named interface { String() string } @@ -165,6 +171,7 @@ var _ Named = (*Union)(nil) var _ Named = (*Enum)(nil) var _ Named = (*InputObject)(nil) +// GetNamed returns the Named type of the given GraphQL type func GetNamed(ttype Type) Named { unmodifiedType := ttype for { @@ -181,23 +188,21 @@ func GetNamed(ttype Type) Named { return unmodifiedType } -/** - * Scalar Type Definition - * - * The leaf values of any request and input values to arguments are - * Scalars (or Enums) and are defined with a name and a series of functions - * used to parse input from ast or variables and to ensure validity. - * - * Example: - * - * var OddType = new Scalar({ - * name: 'Odd', - * serialize(value) { - * return value % 2 === 1 ? value : null; - * } - * }); - * - */ +// Scalar Type Definition +// +// The leaf values of any request and input values to arguments are +// Scalars (or Enums) and are defined with a name and a series of functions +// used to parse input from ast or variables and to ensure validity. +// +// Example: +// +// var OddType = new Scalar({ +// name: 'Odd', +// serialize(value) { +// return value % 2 === 1 ? value : null; +// } +// }); +// type Scalar struct { PrivateName string `json:"name"` PrivateDescription string `json:"description"` @@ -205,9 +210,17 @@ type Scalar struct { scalarConfig ScalarConfig err error } + +// SerializeFn is a function type for serializing a GraphQLScalar type value type SerializeFn func(value interface{}) interface{} + +// ParseValueFn is a function type for parsing the value of a GraphQLScalar type type ParseValueFn func(value interface{}) interface{} + +// ParseLiteralFn is a function type for parsing the literal value of a GraphQLScalar type type ParseLiteralFn func(valueAST ast.Value) interface{} + +// ScalarConfig options for creating a new GraphQLScalar type ScalarConfig struct { Name string `json:"name"` Description string `json:"description"` @@ -216,6 +229,7 @@ type ScalarConfig struct { ParseLiteral ParseLiteralFn } +// NewScalar creates a new GraphQLScalar func NewScalar(config ScalarConfig) *Scalar { st := &Scalar{} err := invariant(config.Name != "", "Type must be named.") @@ -289,43 +303,41 @@ func (st *Scalar) Error() error { return st.err } -/** - * Object Type Definition - * - * Almost all of the GraphQL types you define will be object Object types - * have a name, but most importantly describe their fields. - * - * Example: - * - * var AddressType = new Object({ - * name: 'Address', - * fields: { - * street: { type: String }, - * number: { type: Int }, - * formatted: { - * type: String, - * resolve(obj) { - * return obj.number + ' ' + obj.street - * } - * } - * } - * }); - * - * When two types need to refer to each other, or a type needs to refer to - * itself in a field, you can use a function expression (aka a closure or a - * thunk) to supply the fields lazily. - * - * Example: - * - * var PersonType = new Object({ - * name: 'Person', - * fields: () => ({ - * name: { type: String }, - * bestFriend: { type: PersonType }, - * }) - * }); - * - */ +// Object Type Definition +// +// Almost all of the GraphQL types you define will be object Object types +// have a name, but most importantly describe their fields. +// Example: +// +// var AddressType = new Object({ +// name: 'Address', +// fields: { +// street: { type: String }, +// number: { type: Int }, +// formatted: { +// type: String, +// resolve(obj) { +// return obj.number + ' ' + obj.street +// } +// } +// } +// }); +// +// When two types need to refer to each other, or a type needs to refer to +// itself in a field, you can use a function expression (aka a closure or a +// thunk) to supply the fields lazily. +// +// Example: +// +// var PersonType = new Object({ +// name: 'Person', +// fields: () => ({ +// name: { type: String }, +// bestFriend: { type: PersonType }, +// }) +// }); +// +// / type Object struct { PrivateName string `json:"name"` PrivateDescription string `json:"description"` @@ -428,7 +440,7 @@ func (gt *Object) Interfaces() []*Interface { configInterfaces = gt.typeConfig.Interfaces.([]*Interface) case nil: default: - gt.err = errors.New(fmt.Sprintf("Unknown Object.Interfaces type: %v", reflect.TypeOf(gt.typeConfig.Interfaces))) + gt.err = fmt.Errorf("Unknown Object.Interfaces type: %v", reflect.TypeOf(gt.typeConfig.Interfaces)) return nil } interfaces, err := defineInterfaces(gt, configInterfaces) @@ -543,6 +555,7 @@ func defineFieldMap(ttype Named, fields Fields) (FieldDefinitionMap, error) { return resultFieldMap, nil } +// ResolveParams Params for FieldResolveFn() // TODO: clean up GQLFRParams fields type ResolveParams struct { Source interface{} @@ -554,7 +567,6 @@ type ResolveParams struct { Context context.Context } -// TODO: relook at FieldResolveFn params type FieldResolveFn func(p ResolveParams) (interface{}, error) type ResolveInfo struct { @@ -626,24 +638,23 @@ func (st *Argument) Error() error { return nil } -/** - * Interface Type Definition - * - * When a field can return one of a heterogeneous set of types, a Interface type - * is used to describe what types are possible, what fields are in common across - * all types, as well as a function to determine which type is actually used - * when the field is resolved. - * - * Example: - * - * var EntityType = new Interface({ - * name: 'Entity', - * fields: { - * name: { type: String } - * } - * }); - * - */ +// Interface Type Definition +// +// When a field can return one of a heterogeneous set of types, a Interface type +// is used to describe what types are possible, what fields are in common across +// all types, as well as a function to determine which type is actually used +// when the field is resolved. +// +// Example: +// +// var EntityType = new Interface({ +// name: 'Entity', +// fields: { +// name: { type: String } +// } +// }); +// +// type Interface struct { PrivateName string `json:"name"` PrivateDescription string `json:"description"` @@ -657,11 +668,12 @@ type Interface struct { err error } type InterfaceConfig struct { - Name string `json:"name"` - Fields Fields `json:"fields"` + Name string `json:"name"` + Fields interface{} `json:"fields"` ResolveType ResolveTypeFn Description string `json:"description"` } + type ResolveTypeFn func(value interface{}, info ResolveInfo) *Object func NewInterface(config InterfaceConfig) *Interface { @@ -690,7 +702,10 @@ func (it *Interface) AddFieldConfig(fieldName string, fieldConfig *Field) { if fieldName == "" || fieldConfig == nil { return } - it.typeConfig.Fields[fieldName] = fieldConfig + switch it.typeConfig.Fields.(type) { + case Fields: + it.typeConfig.Fields.(Fields)[fieldName] = fieldConfig + } } func (it *Interface) Name() string { return it.PrivateName @@ -699,7 +714,16 @@ func (it *Interface) Description() string { return it.PrivateDescription } func (it *Interface) Fields() (fields FieldDefinitionMap) { - it.fields, it.err = defineFieldMap(it, it.typeConfig.Fields) + var configureFields Fields + switch it.typeConfig.Fields.(type) { + case Fields: + configureFields = it.typeConfig.Fields.(Fields) + case FieldsThunk: + configureFields = it.typeConfig.Fields.(FieldsThunk)() + } + fields, err := defineFieldMap(it, configureFields) + it.err = err + it.fields = fields return it.fields } func (it *Interface) PossibleTypes() []*Object { @@ -750,29 +774,26 @@ func getTypeOf(value interface{}, info ResolveInfo, abstractType Abstract) *Obje return nil } -/** - * Union Type Definition - * - * When a field can return one of a heterogeneous set of types, a Union type - * is used to describe what types are possible as well as providing a function - * to determine which type is actually used when the field is resolved. - * - * Example: - * - * var PetType = new Union({ - * name: 'Pet', - * types: [ DogType, CatType ], - * resolveType(value) { - * if (value instanceof Dog) { - * return DogType; - * } - * if (value instanceof Cat) { - * return CatType; - * } - * } - * }); - * - */ +// Union Type Definition +// +// When a field can return one of a heterogeneous set of types, a Union type +// is used to describe what types are possible as well as providing a function +// to determine which type is actually used when the field is resolved. +// +// Example: +// +// var PetType = new Union({ +// name: 'Pet', +// types: [ DogType, CatType ], +// resolveType(value) { +// if (value instanceof Dog) { +// return DogType; +// } +// if (value instanceof Cat) { +// return CatType; +// } +// } +// }); type Union struct { PrivateName string `json:"name"` PrivateDescription string `json:"description"` @@ -887,27 +908,26 @@ func (ut *Union) Error() error { return ut.err } -/** - * Enum Type Definition - * - * Some leaf values of requests and input values are Enums. GraphQL serializes - * Enum values as strings, however internally Enums can be represented by any - * kind of type, often integers. - * - * Example: - * - * var RGBType = new Enum({ - * name: 'RGB', - * values: { - * RED: { value: 0 }, - * GREEN: { value: 1 }, - * BLUE: { value: 2 } - * } - * }); - * - * Note: If a value is not provided in a definition, the name of the enum value - * will be used as it's internal value. - */ +// Enum Type Definition +// +// Some leaf values of requests and input values are Enums. GraphQL serializes +// Enum values as strings, however internally Enums can be represented by any +// kind of type, often integers. +// +// Example: +// +// var RGBType = new Enum({ +// name: 'RGB', +// values: { +// RED: { value: 0 }, +// GREEN: { value: 1 }, +// BLUE: { value: 2 } +// } +// }); +// +// Note: If a value is not provided in a definition, the name of the enum value +// will be used as its internal value. + type Enum struct { PrivateName string `json:"name"` PrivateDescription string `json:"description"` @@ -1057,26 +1077,23 @@ func (gt *Enum) getNameLookup() map[string]*EnumValueDefinition { return gt.nameLookup } -/** - * Input Object Type Definition - * - * An input object defines a structured collection of fields which may be - * supplied to a field argument. - * - * Using `NonNull` will ensure that a value must be provided by the query - * - * Example: - * - * var GeoPoint = new InputObject({ - * name: 'GeoPoint', - * fields: { - * lat: { type: new NonNull(Float) }, - * lon: { type: new NonNull(Float) }, - * alt: { type: Float, defaultValue: 0 }, - * } - * }); - * - */ +// InputObject Type Definition +// +// An input object defines a structured collection of fields which may be +// supplied to a field argument. +// +// Using `NonNull` will ensure that a value must be provided by the query +// +// Example: +// +// var GeoPoint = new InputObject({ +// name: 'GeoPoint', +// fields: { +// lat: { type: new NonNull(Float) }, +// lon: { type: new NonNull(Float) }, +// alt: { type: Float, defaultValue: 0 }, +// } +// }); type InputObject struct { PrivateName string `json:"name"` PrivateDescription string `json:"description"` @@ -1121,7 +1138,6 @@ type InputObjectConfig struct { Description string `json:"description"` } -// TODO: rename InputObjectConfig to GraphQLInputObjecTypeConfig for consistency? func NewInputObject(config InputObjectConfig) *InputObject { gt := &InputObject{} err := invariant(config.Name != "", "Type must be named.") @@ -1197,24 +1213,22 @@ func (gt *InputObject) Error() error { return gt.err } -/** - * List Modifier - * - * A list is a kind of type marker, a wrapping type which points to another - * type. Lists are often created within the context of defining the fields of - * an object type. - * - * Example: - * - * var PersonType = new Object({ - * name: 'Person', - * fields: () => ({ - * parents: { type: new List(Person) }, - * children: { type: new List(Person) }, - * }) - * }) - * - */ +// List Modifier +// +// A list is a kind of type marker, a wrapping type which points to another +// type. Lists are often created within the context of defining the fields of +// an object type. +// +// Example: +// +// var PersonType = new Object({ +// name: 'Person', +// fields: () => ({ +// parents: { type: new List(Person) }, +// children: { type: new List(Person) }, +// }) +// }) +// type List struct { OfType Type `json:"ofType"` @@ -1249,26 +1263,24 @@ func (gl *List) Error() error { return gl.err } -/** - * Non-Null Modifier - * - * A non-null is a kind of type marker, a wrapping type which points to another - * type. Non-null types enforce that their values are never null and can ensure - * an error is raised if this ever occurs during a request. It is useful for - * fields which you can make a strong guarantee on non-nullability, for example - * usually the id field of a database row will never be null. - * - * Example: - * - * var RowType = new Object({ - * name: 'Row', - * fields: () => ({ - * id: { type: new NonNull(String) }, - * }) - * }) - * - * Note: the enforcement of non-nullability occurs within the executor. - */ +// NonNull Modifier +// +// A non-null is a kind of type marker, a wrapping type which points to another +// type. Non-null types enforce that their values are never null and can ensure +// an error is raised if this ever occurs during a request. It is useful for +// fields which you can make a strong guarantee on non-nullability, for example +// usually the id field of a database row will never be null. +// +// Example: +// +// var RowType = new Object({ +// name: 'Row', +// fields: () => ({ +// id: { type: new NonNull(String) }, +// }) +// }) +// +// Note: the enforcement of non-nullability occurs within the executor. type NonNull struct { PrivateName string `json:"name"` // added to conform with introspection for NonNull.Name = nil OfType Type `json:"ofType"` @@ -1304,11 +1316,11 @@ func (gl *NonNull) Error() error { return gl.err } -var NAME_REGEXP, _ = regexp.Compile("^[_a-zA-Z][_a-zA-Z0-9]*$") +var NameRegExp, _ = regexp.Compile("^[_a-zA-Z][_a-zA-Z0-9]*$") func assertValidName(name string) error { return invariant( - NAME_REGEXP.MatchString(name), + NameRegExp.MatchString(name), fmt.Sprintf(`Names must match /^[_a-zA-Z][_a-zA-Z0-9]*$/ but "%v" does not.`, name), ) } diff --git a/definition_test.go b/definition_test.go index 6664feab..363c8024 100644 --- a/definition_test.go +++ b/definition_test.go @@ -98,6 +98,20 @@ var blogMutation = graphql.NewObject(graphql.ObjectConfig{ }, }) +var blogSubscription = graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "articleSubscribe": &graphql.Field{ + Type: blogArticle, + Args: graphql.FieldConfigArgument{ + "id": &graphql.ArgumentConfig{ + Type: graphql.String, + }, + }, + }, + }, +}) + var objectType = graphql.NewObject(graphql.ObjectConfig{ Name: "Object", IsTypeOf: func(value interface{}, info graphql.ResolveInfo) bool { @@ -204,6 +218,7 @@ func TestTypeSystem_DefinitionExample_DefinesAQueryOnlySchema(t *testing.T) { t.Fatalf("feedField.Name expected to equal `feed`, got: %v", feedField.Name) } } + func TestTypeSystem_DefinitionExample_DefinesAMutationScheme(t *testing.T) { blogSchema, err := graphql.NewSchema(graphql.SchemaConfig{ Query: blogQuery, @@ -233,6 +248,35 @@ func TestTypeSystem_DefinitionExample_DefinesAMutationScheme(t *testing.T) { } } +func TestTypeSystem_DefinitionExample_DefinesASubscriptionScheme(t *testing.T) { + blogSchema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: blogQuery, + Subscription: blogSubscription, + }) + if err != nil { + t.Fatalf("unexpected error, got: %v", err) + } + + if blogSchema.SubscriptionType() != blogSubscription { + t.Fatalf("expected blogSchema.SubscriptionType() == blogSubscription") + } + + subMutation, _ := blogSubscription.Fields()["articleSubscribe"] + if subMutation == nil { + t.Fatalf("subMutation is nil") + } + subMutationType := subMutation.Type + if subMutationType != blogArticle { + t.Fatalf("subMutationType expected to equal blogArticle, got: %v", subMutationType) + } + if subMutationType.Name() != "Article" { + t.Fatalf("subMutationType.Name expected to equal `Article`, got: %v", subMutationType.Name()) + } + if subMutation.Name != "articleSubscribe" { + t.Fatalf("subMutation.Name expected to equal `articleSubscribe`, got: %v", subMutation.Name) + } +} + func TestTypeSystem_DefinitionExample_IncludesNestedInputObjectsInTheMap(t *testing.T) { nestedInputObject := graphql.NewInputObject(graphql.InputObjectConfig{ Name: "NestedInputObject", @@ -263,9 +307,23 @@ func TestTypeSystem_DefinitionExample_IncludesNestedInputObjectsInTheMap(t *test }, }, }) + someSubscription := graphql.NewObject(graphql.ObjectConfig{ + Name: "SomeSubscription", + Fields: graphql.Fields{ + "subscribeToSomething": &graphql.Field{ + Type: blogArticle, + Args: graphql.FieldConfigArgument{ + "input": &graphql.ArgumentConfig{ + Type: someInputObject, + }, + }, + }, + }, + }) schema, err := graphql.NewSchema(graphql.SchemaConfig{ - Query: blogQuery, - Mutation: someMutation, + Query: blogQuery, + Mutation: someMutation, + Subscription: someSubscription, }) if err != nil { t.Fatalf("unexpected error, got: %v", err) diff --git a/directives.go b/directives.go index 67a1aa0d..63411104 100644 --- a/directives.go +++ b/directives.go @@ -1,5 +1,7 @@ package graphql +// Directive structs are used by the GraphQL runtime as a way of modifying execution +// behavior. Type system creators will usually not create these directly. type Directive struct { Name string `json:"name"` Description string `json:"description"` @@ -9,10 +11,6 @@ type Directive struct { OnField bool `json:"onField"` } -/** - * Directives are used by the GraphQL runtime as a way of modifying execution - * behavior. Type system creators will usually not create these directly. - */ func NewDirective(config *Directive) *Directive { if config == nil { config = &Directive{} @@ -27,10 +25,8 @@ func NewDirective(config *Directive) *Directive { } } -/** - * Used to conditionally include fields or fragments - */ -var IncludeDirective *Directive = NewDirective(&Directive{ +// IncludeDirective is used to conditionally include fields or fragments +var IncludeDirective = NewDirective(&Directive{ Name: "include", Description: "Directs the executor to include this field or fragment only when " + "the `if` argument is true.", @@ -46,10 +42,8 @@ var IncludeDirective *Directive = NewDirective(&Directive{ OnField: true, }) -/** - * Used to conditionally skip (exclude) fields or fragments - */ -var SkipDirective *Directive = NewDirective(&Directive{ +// SkipDirective Used to conditionally skip (exclude) fields or fragments +var SkipDirective = NewDirective(&Directive{ Name: "skip", Description: "Directs the executor to skip this field or fragment when the `if` " + "argument is true.", diff --git a/directives_test.go b/directives_test.go index 5c87aa5d..f376015c 100644 --- a/directives_test.go +++ b/directives_test.go @@ -324,6 +324,100 @@ func TestDirectivesWorksOnInlineFragmentUnlessTrueIncludesInlineFragment(t *test } } +func TestDirectivesWorksOnAnonymousInlineFragmentIfFalseOmitsAnonymousInlineFragment(t *testing.T) { + query := ` + query Q { + a + ... @include(if: false) { + b + } + } + ` + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "a", + }, + } + result := executeDirectivesTestQuery(t, query) + if len(result.Errors) != 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + +func TestDirectivesWorksOnAnonymousInlineFragmentIfTrueIncludesAnonymousInlineFragment(t *testing.T) { + query := ` + query Q { + a + ... @include(if: true) { + b + } + } + ` + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "a", + "b": "b", + }, + } + result := executeDirectivesTestQuery(t, query) + if len(result.Errors) != 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + +func TestDirectivesWorksOnAnonymousInlineFragmentUnlessFalseIncludesAnonymousInlineFragment(t *testing.T) { + query := ` + query Q { + a + ... @skip(if: false) { + b + } + } + ` + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "a", + "b": "b", + }, + } + result := executeDirectivesTestQuery(t, query) + if len(result.Errors) != 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + +func TestDirectivesWorksOnAnonymousInlineFragmentUnlessTrueIncludesAnonymousInlineFragment(t *testing.T) { + query := ` + query Q { + a + ... @skip(if: true) { + b + } + } + ` + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "a", + }, + } + result := executeDirectivesTestQuery(t, query) + if len(result.Errors) != 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + func TestDirectivesWorksOnFragmentIfFalseOmitsFragment(t *testing.T) { query := ` query Q { diff --git a/enum_type_test.go b/enum_type_test.go index 7187a686..b436dd09 100644 --- a/enum_type_test.go +++ b/enum_type_test.go @@ -6,6 +6,7 @@ import ( "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/gqlerrors" + "github.com/graphql-go/graphql/language/location" "github.com/graphql-go/graphql/testutil" ) @@ -93,9 +94,31 @@ var enumTypeTestMutationType = graphql.NewObject(graphql.ObjectConfig{ }, }, }) + +var enumTypeTestSubscriptionType = graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "subscribeToEnum": &graphql.Field{ + Type: enumTypeTestColorType, + Args: graphql.FieldConfigArgument{ + "color": &graphql.ArgumentConfig{ + Type: enumTypeTestColorType, + }, + }, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + if color, ok := p.Args["color"]; ok { + return color, nil + } + return nil, nil + }, + }, + }, +}) + var enumTypeTestSchema, _ = graphql.NewSchema(graphql.SchemaConfig{ - Query: enumTypeTestQueryType, - Mutation: enumTypeTestMutationType, + Query: enumTypeTestQueryType, + Mutation: enumTypeTestMutationType, + Subscription: enumTypeTestSubscriptionType, }) func executeEnumTypeTest(t *testing.T, query string) *graphql.Result { @@ -156,7 +179,10 @@ func TestTypeSystem_EnumValues_DoesNotAcceptStringLiterals(t *testing.T) { Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Argument "fromEnum" expected type "Color" but got: "GREEN".`, + Message: "Argument \"fromEnum\" has invalid value \"GREEN\".\nExpected type \"Color\", found \"GREEN\".", + Locations: []location.SourceLocation{ + {Line: 1, Column: 23}, + }, }, }, } @@ -183,7 +209,10 @@ func TestTypeSystem_EnumValues_DoesNotAcceptInternalValueInPlaceOfEnumLiteral(t Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Argument "fromEnum" expected type "Color" but got: 1.`, + Message: "Argument \"fromEnum\" has invalid value 1.\nExpected type \"Color\", found 1.", + Locations: []location.SourceLocation{ + {Line: 1, Column: 23}, + }, }, }, } @@ -199,7 +228,10 @@ func TestTypeSystem_EnumValues_DoesNotAcceptEnumLiteralInPlaceOfInt(t *testing.T Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Argument "fromInt" expected type "Int" but got: GREEN.`, + Message: "Argument \"fromInt\" has invalid value GREEN.\nExpected type \"Int\", found GREEN.", + Locations: []location.SourceLocation{ + {Line: 1, Column: 23}, + }, }, }, } @@ -240,6 +272,22 @@ func TestTypeSystem_EnumValues_AcceptsEnumLiteralsAsInputArgumentsToMutations(t t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) } } + +func TestTypeSystem_EnumValues_AcceptsEnumLiteralsAsInputArgumentsToSubscriptions(t *testing.T) { + query := `subscription x($color: Color!) { subscribeToEnum(color: $color) }` + params := map[string]interface{}{ + "color": "GREEN", + } + expected := &graphql.Result{ + Data: map[string]interface{}{ + "subscribeToEnum": "GREEN", + }, + } + result := executeEnumTypeTestWithParams(t, query, params) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} func TestTypeSystem_EnumValues_DoesNotAcceptInternalValueAsEnumVariable(t *testing.T) { query := `query test($color: Color!) { colorEnum(fromEnum: $color) }` params := map[string]interface{}{ @@ -249,7 +297,10 @@ func TestTypeSystem_EnumValues_DoesNotAcceptInternalValueAsEnumVariable(t *testi Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$color" expected value of type "Color!" but got: 2.`, + Message: "Variable \"$color\" got invalid value 2.\nExpected type \"Color\", found \"2\".", + Locations: []location.SourceLocation{ + {Line: 1, Column: 12}, + }, }, }, } diff --git a/examples/todo/main.go b/examples/todo/main.go index 6a3f3c98..f5b69629 100644 --- a/examples/todo/main.go +++ b/examples/todo/main.go @@ -75,12 +75,12 @@ var rootMutation = graphql.NewObject(graphql.ObjectConfig{ text, _ := params.Args["text"].(string) // figure out new id - newId := RandStringRunes(8) + newID := RandStringRunes(8) // perform mutation operation here // for e.g. create a Todo and save to DB. newTodo := Todo{ - ID: newId, + ID: newID, Text: text, Done: false, } diff --git a/executor.go b/executor.go index fb4e14c1..b810f358 100644 --- a/executor.go +++ b/executor.go @@ -82,43 +82,34 @@ type ExecutionContext struct { func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) { eCtx := &ExecutionContext{} - operations := map[string]ast.Definition{} + var operation *ast.OperationDefinition fragments := map[string]ast.Definition{} - for _, statement := range p.AST.Definitions { - switch stm := statement.(type) { + + for _, definition := range p.AST.Definitions { + switch definition := definition.(type) { case *ast.OperationDefinition: - key := "" - if stm.GetName() != nil && stm.GetName().Value != "" { - key = stm.GetName().Value + if (p.OperationName == "") && operation != nil { + return nil, errors.New("Must provide operation name if query contains multiple operations.") + } + if p.OperationName == "" || definition.GetName() != nil && definition.GetName().Value == p.OperationName { + operation = definition } - operations[key] = stm case *ast.FragmentDefinition: key := "" - if stm.GetName() != nil && stm.GetName().Value != "" { - key = stm.GetName().Value + if definition.GetName() != nil && definition.GetName().Value != "" { + key = definition.GetName().Value } - fragments[key] = stm + fragments[key] = definition default: - return nil, fmt.Errorf("GraphQL cannot execute a request containing a %v", statement.GetKind()) + return nil, fmt.Errorf("GraphQL cannot execute a request containing a %v", definition.GetKind()) } } - if (p.OperationName == "") && (len(operations) != 1) { - return nil, errors.New("Must provide operation name if query contains multiple operations.") - } - - opName := p.OperationName - if opName == "" { - // get first opName - for k, _ := range operations { - opName = k - break + if operation == nil { + if p.OperationName == "" { + return nil, fmt.Errorf(`Unknown operation named "%v".`, p.OperationName) } - } - - operation, found := operations[opName] - if !found { - return nil, fmt.Errorf(`Unknown operation named "%v".`, opName) + return nil, fmt.Errorf(`Must provide an operation`) } variableValues, err := getVariableValues(p.Schema, operation.GetVariableDefinitions(), p.Args) @@ -149,9 +140,9 @@ func executeOperation(p ExecuteOperationParams) *Result { } fields := collectFields(CollectFieldsParams{ - ExeContext: p.ExecutionContext, - OperationType: operationType, - SelectionSet: p.Operation.GetSelectionSet(), + ExeContext: p.ExecutionContext, + RuntimeType: operationType, + SelectionSet: p.Operation.GetSelectionSet(), }) executeFieldsParams := ExecuteFieldsParams{ @@ -163,9 +154,9 @@ func executeOperation(p ExecuteOperationParams) *Result { if p.Operation.GetOperation() == "mutation" { return executeFieldsSerially(executeFieldsParams) - } else { - return executeFields(executeFieldsParams) } + return executeFields(executeFieldsParams) + } // Extracts the root type of the operation from the schema. @@ -180,11 +171,38 @@ func getOperationRootType(schema Schema, operation ast.Definition) (*Object, err case "mutation": mutationType := schema.MutationType() if mutationType.PrivateName == "" { - return nil, errors.New("Schema is not configured for mutations") + return nil, gqlerrors.NewError( + "Schema is not configured for mutations", + []ast.Node{operation}, + "", + nil, + []int{}, + nil, + ) } return mutationType, nil + case "subscription": + subscriptionType := schema.SubscriptionType() + if subscriptionType.PrivateName == "" { + return nil, gqlerrors.NewError( + "Schema is not configured for subscriptions", + []ast.Node{operation}, + "", + nil, + []int{}, + nil, + ) + } + return subscriptionType, nil default: - return nil, errors.New("Can only execute queries and mutations") + return nil, gqlerrors.NewError( + "Can only execute queries, mutations and subscription", + []ast.Node{operation}, + "", + nil, + []int{}, + nil, + ) } } @@ -245,7 +263,7 @@ func executeFields(p ExecuteFieldsParams) *Result { type CollectFieldsParams struct { ExeContext *ExecutionContext - OperationType *Object + RuntimeType *Object // previously known as OperationType SelectionSet *ast.SelectionSet Fields map[string][]*ast.Field VisitedFragmentNames map[string]bool @@ -253,6 +271,9 @@ type CollectFieldsParams struct { // Given a selectionSet, adds all of the fields in that selection to // the passed in map of fields, and returns it at the end. +// CollectFields requires the "runtime type" of an object. For a field which +// returns and Interface or Union type, the "runtime type" will be the actual +// Object type returned by that field. func collectFields(p CollectFieldsParams) map[string][]*ast.Field { fields := p.Fields @@ -279,12 +300,12 @@ func collectFields(p CollectFieldsParams) map[string][]*ast.Field { case *ast.InlineFragment: if !shouldIncludeNode(p.ExeContext, selection.Directives) || - !doesFragmentConditionMatch(p.ExeContext, selection, p.OperationType) { + !doesFragmentConditionMatch(p.ExeContext, selection, p.RuntimeType) { continue } innerParams := CollectFieldsParams{ ExeContext: p.ExeContext, - OperationType: p.OperationType, + RuntimeType: p.RuntimeType, SelectionSet: selection.SelectionSet, Fields: fields, VisitedFragmentNames: p.VisitedFragmentNames, @@ -307,12 +328,12 @@ func collectFields(p CollectFieldsParams) map[string][]*ast.Field { if fragment, ok := fragment.(*ast.FragmentDefinition); ok { if !shouldIncludeNode(p.ExeContext, fragment.Directives) || - !doesFragmentConditionMatch(p.ExeContext, fragment, p.OperationType) { + !doesFragmentConditionMatch(p.ExeContext, fragment, p.RuntimeType) { continue } innerParams := CollectFieldsParams{ ExeContext: p.ExeContext, - OperationType: p.OperationType, + RuntimeType: p.RuntimeType, SelectionSet: fragment.GetSelectionSet(), Fields: fields, VisitedFragmentNames: p.VisitedFragmentNames, @@ -390,29 +411,38 @@ func doesFragmentConditionMatch(eCtx *ExecutionContext, fragment ast.Node, ttype switch fragment := fragment.(type) { case *ast.FragmentDefinition: - conditionalType, err := typeFromAST(eCtx.Schema, fragment.TypeCondition) + typeConditionAST := fragment.TypeCondition + if typeConditionAST == nil { + return true + } + conditionalType, err := typeFromAST(eCtx.Schema, typeConditionAST) if err != nil { return false } if conditionalType == ttype { return true } - if conditionalType.Name() == ttype.Name() { + if conditionalType.Name() == ttype.Name() { return true } - if conditionalType, ok := conditionalType.(Abstract); ok { return conditionalType.IsPossibleType(ttype) } case *ast.InlineFragment: - conditionalType, err := typeFromAST(eCtx.Schema, fragment.TypeCondition) + typeConditionAST := fragment.TypeCondition + if typeConditionAST == nil { + return true + } + conditionalType, err := typeFromAST(eCtx.Schema, typeConditionAST) if err != nil { return false } if conditionalType == ttype { return true } - + if conditionalType.Name() == ttype.Name() { + return true + } if conditionalType, ok := conditionalType.(Abstract); ok { return conditionalType.IsPossibleType(ttype) } @@ -438,12 +468,10 @@ type resolveFieldResultState struct { hasNoFieldDefs bool } -/** - * Resolves the field on the given source object. In particular, this - * figures out the value that the field returns by calling its resolve function, - * then calls completeValue to complete promises, serialize scalars, or execute - * the sub-selection-set for objects. - */ +// Resolves the field on the given source object. In particular, this +// figures out the value that the field returns by calling its resolve function, +// then calls completeValue to complete promises, serialize scalars, or execute +// the sub-selection-set for objects. func resolveField(eCtx *ExecutionContext, parentType *Object, source interface{}, fieldASTs []*ast.Field) (result interface{}, resultState resolveFieldResultState) { // catch panic from resolveFn var returnType Output @@ -506,10 +534,6 @@ func resolveField(eCtx *ExecutionContext, parentType *Object, source interface{} VariableValues: eCtx.VariableValues, } - // TODO: If an error occurs while calling the field `resolve` function, ensure that - // it is wrapped as a Error with locations. Log this error and return - // null if allowed, otherwise throw the error so the parent field can handle - // it. var resolveFnError error result, resolveFnError = resolveFn(ResolveParams{ @@ -592,9 +616,14 @@ func completeValue(eCtx *ExecutionContext, returnType Type, fieldASTs []*ast.Fie if returnType, ok := returnType.(*List); ok { resultVal := reflect.ValueOf(result) + parentTypeName := "" + if info.ParentType != nil { + parentTypeName = info.ParentType.Name() + } err := invariant( resultVal.IsValid() && resultVal.Type().Kind() == reflect.Slice, - "User Error: expected iterable, but did not find one.", + fmt.Sprintf("User Error: expected iterable, but did not find one "+ + "for field %v.%v.", parentTypeName, info.FieldName), ) if err != nil { panic(gqlerrors.FormatError(err)) @@ -628,29 +657,29 @@ func completeValue(eCtx *ExecutionContext, returnType Type, fieldASTs []*ast.Fie } // ast.Field type must be Object, Interface or Union and expect sub-selections. - var objectType *Object + var runtimeType *Object switch returnType := returnType.(type) { case *Object: - objectType = returnType + runtimeType = returnType case Abstract: - objectType = returnType.ObjectType(result, info) - if objectType != nil && !returnType.IsPossibleType(objectType) { + runtimeType = returnType.ObjectType(result, info) + if runtimeType != nil && !returnType.IsPossibleType(runtimeType) { panic(gqlerrors.NewFormattedError( fmt.Sprintf(`Runtime Object type "%v" is not a possible type `+ - `for "%v".`, objectType, returnType), + `for "%v".`, runtimeType, returnType), )) } } - if objectType == nil { + if runtimeType == nil { return nil } // If there is an isTypeOf predicate function, call it with the // current result. If isTypeOf returns false, then raise an error rather // than continuing execution. - if objectType.IsTypeOf != nil && !objectType.IsTypeOf(result, info) { + if runtimeType.IsTypeOf != nil && !runtimeType.IsTypeOf(result, info) { panic(gqlerrors.NewFormattedError( - fmt.Sprintf(`Expected value of type "%v" but got: %T.`, objectType, result), + fmt.Sprintf(`Expected value of type "%v" but got: %T.`, runtimeType, result), )) } @@ -665,7 +694,7 @@ func completeValue(eCtx *ExecutionContext, returnType Type, fieldASTs []*ast.Fie if selectionSet != nil { innerParams := CollectFieldsParams{ ExeContext: eCtx, - OperationType: objectType, + RuntimeType: runtimeType, SelectionSet: selectionSet, Fields: subFieldASTs, VisitedFragmentNames: visitedFragmentNames, @@ -675,7 +704,7 @@ func completeValue(eCtx *ExecutionContext, returnType Type, fieldASTs []*ast.Fie } executeFieldsParams := ExecuteFieldsParams{ ExecutionContext: eCtx, - ParentType: objectType, + ParentType: runtimeType, Source: result, Fields: subFieldASTs, } @@ -738,15 +767,13 @@ func defaultResolveFn(p ResolveParams) (interface{}, error) { return nil, nil } -/** - * This method looks up the field on the given type defintion. - * It has special casing for the two introspection fields, __schema - * and __typename. __typename is special because it can always be - * queried as a field, even in situations where no other fields - * are allowed, like on a Union. __schema could get automatically - * added to the query type, but that would require mutating type - * definitions, which would cause issues. - */ +// This method looks up the field on the given type defintion. +// It has special casing for the two introspection fields, __schema +// and __typename. __typename is special because it can always be +// queried as a field, even in situations where no other fields +// are allowed, like on a Union. __schema could get automatically +// added to the query type, but that would require mutating type +// definitions, which would cause issues. func getFieldDef(schema Schema, parentType *Object, fieldName string) *FieldDefinition { if parentType == nil { diff --git a/executor_test.go b/executor_test.go index 7922d8c8..f7915ec6 100644 --- a/executor_test.go +++ b/executor_test.go @@ -662,7 +662,7 @@ func TestThrowsIfNoOperationIsProvidedWithMultipleOperations(t *testing.T) { func TestUsesTheQuerySchemaForQueries(t *testing.T) { - doc := `query Q { a } mutation M { c }` + doc := `query Q { a } mutation M { c } subscription S { a }` data := map[string]interface{}{ "a": "b", "c": "d", @@ -691,6 +691,14 @@ func TestUsesTheQuerySchemaForQueries(t *testing.T) { }, }, }), + Subscription: graphql.NewObject(graphql.ObjectConfig{ + Name: "S", + Fields: graphql.Fields{ + "a": &graphql.Field{ + Type: graphql.String, + }, + }, + }), }) if err != nil { t.Fatalf("Error in schema %v", err.Error()) @@ -770,6 +778,61 @@ func TestUsesTheMutationSchemaForMutations(t *testing.T) { } } +func TestUsesTheSubscriptionSchemaForSubscriptions(t *testing.T) { + + doc := `query Q { a } subscription S { a }` + data := map[string]interface{}{ + "a": "b", + "c": "d", + } + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "b", + }, + } + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Q", + Fields: graphql.Fields{ + "a": &graphql.Field{ + Type: graphql.String, + }, + }, + }), + Subscription: graphql.NewObject(graphql.ObjectConfig{ + Name: "S", + Fields: graphql.Fields{ + "a": &graphql.Field{ + Type: graphql.String, + }, + }, + }), + }) + if err != nil { + t.Fatalf("Error in schema %v", err.Error()) + } + + // parse query + ast := testutil.TestParse(t, doc) + + // execute + ep := graphql.ExecuteParams{ + Schema: schema, + AST: ast, + Root: data, + OperationName: "S", + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + func TestCorrectFieldOrderingDespiteExecutionOrder(t *testing.T) { doc := ` diff --git a/gqlerrors/error.go b/gqlerrors/error.go index c32fff3c..1809870a 100644 --- a/gqlerrors/error.go +++ b/gqlerrors/error.go @@ -9,12 +9,13 @@ import ( ) type Error struct { - Message string - Stack string - Nodes []ast.Node - Source *source.Source - Positions []int - Locations []location.SourceLocation + Message string + Stack string + Nodes []ast.Node + Source *source.Source + Positions []int + Locations []location.SourceLocation + OriginalError error } // implements Golang's built-in `error` interface @@ -22,7 +23,7 @@ func (g Error) Error() string { return fmt.Sprintf("%v", g.Message) } -func NewError(message string, nodes []ast.Node, stack string, source *source.Source, positions []int) *Error { +func NewError(message string, nodes []ast.Node, stack string, source *source.Source, positions []int, origError error) *Error { if stack == "" && message != "" { stack = message } @@ -49,11 +50,12 @@ func NewError(message string, nodes []ast.Node, stack string, source *source.Sou locations = append(locations, loc) } return &Error{ - Message: message, - Stack: stack, - Nodes: nodes, - Source: source, - Positions: positions, - Locations: locations, + Message: message, + Stack: stack, + Nodes: nodes, + Source: source, + Positions: positions, + Locations: locations, + OriginalError: origError, } } diff --git a/gqlerrors/located.go b/gqlerrors/located.go index d5d1b020..b02fcd8a 100644 --- a/gqlerrors/located.go +++ b/gqlerrors/located.go @@ -1,16 +1,23 @@ package gqlerrors import ( + "errors" "github.com/graphql-go/graphql/language/ast" ) +// NewLocatedError creates a graphql.Error with location info +// @deprecated 0.4.18 +// Already exists in `graphql.NewLocatedError()` func NewLocatedError(err interface{}, nodes []ast.Node) *Error { + var origError error message := "An unknown error occurred." if err, ok := err.(error); ok { message = err.Error() + origError = err } if err, ok := err.(string); ok { message = err + origError = errors.New(err) } stack := message return NewError( @@ -19,6 +26,7 @@ func NewLocatedError(err interface{}, nodes []ast.Node) *Error { stack, nil, []int{}, + origError, ) } diff --git a/gqlerrors/syntax.go b/gqlerrors/syntax.go index 76a39751..4235a040 100644 --- a/gqlerrors/syntax.go +++ b/gqlerrors/syntax.go @@ -7,6 +7,7 @@ import ( "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/location" "github.com/graphql-go/graphql/language/source" + "strings" ) func NewSyntaxError(s *source.Source, position int, description string) *Error { @@ -17,27 +18,44 @@ func NewSyntaxError(s *source.Source, position int, description string) *Error { "", s, []int{position}, + nil, ) } +// printCharCode here is slightly different from lexer.printCharCode() +func printCharCode(code rune) string { + // print as ASCII for printable range + if code >= 0x0020 { + return fmt.Sprintf(`%c`, code) + } + // Otherwise print the escaped form. e.g. `"\\u0007"` + return fmt.Sprintf(`\u%04X`, code) +} +func printLine(str string) string { + strSlice := []string{} + for _, runeValue := range str { + strSlice = append(strSlice, printCharCode(runeValue)) + } + return fmt.Sprintf(`%s`, strings.Join(strSlice, "")) +} func highlightSourceAtLocation(s *source.Source, l location.SourceLocation) string { line := l.Line prevLineNum := fmt.Sprintf("%d", (line - 1)) lineNum := fmt.Sprintf("%d", line) nextLineNum := fmt.Sprintf("%d", (line + 1)) padLen := len(nextLineNum) - lines := regexp.MustCompile("\r\n|[\n\r\u2028\u2029]").Split(s.Body, -1) + lines := regexp.MustCompile("\r\n|[\n\r]").Split(s.Body, -1) var highlight string if line >= 2 { - highlight += fmt.Sprintf("%s: %s\n", lpad(padLen, prevLineNum), lines[line-2]) + highlight += fmt.Sprintf("%s: %s\n", lpad(padLen, prevLineNum), printLine(lines[line-2])) } - highlight += fmt.Sprintf("%s: %s\n", lpad(padLen, lineNum), lines[line-1]) + highlight += fmt.Sprintf("%s: %s\n", lpad(padLen, lineNum), printLine(lines[line-1])) for i := 1; i < (2 + padLen + l.Column); i++ { highlight += " " } highlight += "^\n" if line < len(lines) { - highlight += fmt.Sprintf("%s: %s\n", lpad(padLen, nextLineNum), lines[line]) + highlight += fmt.Sprintf("%s: %s\n", lpad(padLen, nextLineNum), printLine(lines[line])) } return highlight } diff --git a/graphql.go b/graphql.go index db6b86ab..047535b4 100644 --- a/graphql.go +++ b/graphql.go @@ -8,11 +8,24 @@ import ( ) type Params struct { - Schema Schema - RequestString string - RootObject map[string]interface{} + // The GraphQL type system to use when validating and executing a query. + Schema Schema + + // A GraphQL language formatted string representing the requested operation. + RequestString string + + // The value provided as the first argument to resolver functions on the top + // level type (e.g. the query object type). + RootObject map[string]interface{} + + // A mapping of variable name to runtime value to use for all variables + // defined in the requestString. VariableValues map[string]interface{} - OperationName string + + // The name of the operation to use if requestString contains multiple + // possible operations. Can be omitted if requestString contains only + // one operation. + OperationName string // Context may be provided to pass application-specific per-request // information to resolve functions. diff --git a/introspection.go b/introspection.go index 81bc61f4..99e7c04e 100644 --- a/introspection.go +++ b/introspection.go @@ -19,14 +19,14 @@ const ( TypeKindNonNull = "NON_NULL" ) -var __Directive *Object -var __Schema *Object -var __Type *Object -var __Field *Object -var __InputValue *Object -var __EnumValue *Object +var directiveType *Object +var schemaType *Object +var typeType *Object +var fieldType *Object +var inputValueType *Object +var enumValueType *Object -var __TypeKind *Enum +var typeKindEnum *Enum var SchemaMetaFieldDef *FieldDefinition var TypeMetaFieldDef *FieldDefinition @@ -34,9 +34,9 @@ var TypeNameMetaFieldDef *FieldDefinition func init() { - __TypeKind = NewEnum(EnumConfig{ + typeKindEnum = NewEnum(EnumConfig{ Name: "__TypeKind", - Description: "An enum describing what kind of type a given __Type is", + Description: "An enum describing what kind of type a given `__Type` is", Values: EnumValueConfigMap{ "SCALAR": &EnumValueConfig{ Value: TypeKindScalar, @@ -81,11 +81,20 @@ func init() { }) // Note: some fields (for e.g "fields", "interfaces") are defined later due to cyclic reference - __Type = NewObject(ObjectConfig{ + typeType = NewObject(ObjectConfig{ Name: "__Type", + Description: "The fundamental unit of any GraphQL Schema is the type. There are " + + "many kinds of types in GraphQL as represented by the `__TypeKind` enum." + + "\n\nDepending on the kind of a type, certain fields describe " + + "information about that type. Scalar types provide no information " + + "beyond a name and description, while Enum types provide their values. " + + "Object and Interface types provide the fields they describe. Abstract " + + "types, Union and Interface, provide the Object types possible " + + "at runtime. List and NonNull types compose other types.", + Fields: Fields{ "kind": &Field{ - Type: NewNonNull(__TypeKind), + Type: NewNonNull(typeKindEnum), Resolve: func(p ResolveParams) (interface{}, error) { switch p.Source.(type) { case *Scalar: @@ -123,8 +132,11 @@ func init() { }, }) - __InputValue = NewObject(ObjectConfig{ + inputValueType = NewObject(ObjectConfig{ Name: "__InputValue", + Description: "Arguments provided to Fields or Directives and the input fields of an " + + "InputObject are represented as Input Values which describe their type " + + "and optionally a default value.", Fields: Fields{ "name": &Field{ Type: NewNonNull(String), @@ -133,15 +145,20 @@ func init() { Type: String, }, "type": &Field{ - Type: NewNonNull(__Type), + Type: NewNonNull(typeType), }, "defaultValue": &Field{ Type: String, + Description: "A GraphQL-formatted string representing the default value for this " + + "input value.", Resolve: func(p ResolveParams) (interface{}, error) { if inputVal, ok := p.Source.(*Argument); ok { if inputVal.DefaultValue == nil { return nil, nil } + if isNullish(inputVal.DefaultValue) { + return nil, nil + } astVal := astFromValue(inputVal.DefaultValue, inputVal) return printer.Print(astVal), nil } @@ -158,8 +175,10 @@ func init() { }, }) - __Field = NewObject(ObjectConfig{ + fieldType = NewObject(ObjectConfig{ Name: "__Field", + Description: "Object and Interface types are described by a list of Fields, each of " + + "which has a name, potentially a list of arguments, and a return type.", Fields: Fields{ "name": &Field{ Type: NewNonNull(String), @@ -168,7 +187,7 @@ func init() { Type: String, }, "args": &Field{ - Type: NewNonNull(NewList(NewNonNull(__InputValue))), + Type: NewNonNull(NewList(NewNonNull(inputValueType))), Resolve: func(p ResolveParams) (interface{}, error) { if field, ok := p.Source.(*FieldDefinition); ok { return field.Args, nil @@ -177,7 +196,7 @@ func init() { }, }, "type": &Field{ - Type: NewNonNull(__Type), + Type: NewNonNull(typeType), }, "isDeprecated": &Field{ Type: NewNonNull(Boolean), @@ -194,8 +213,14 @@ func init() { }, }) - __Directive = NewObject(ObjectConfig{ + directiveType = NewObject(ObjectConfig{ Name: "__Directive", + Description: "A Directive provides a way to describe alternate runtime execution and " + + "type validation behavior in a GraphQL document. " + + "\n\nIn some cases, you need to provide options to alter GraphQL's " + + "execution behavior in ways field arguments will not suffice, such as " + + "conditionally including or skipping a field. Directives provide this by " + + "describing additional information to the executor.", Fields: Fields{ "name": &Field{ Type: NewNonNull(String), @@ -205,7 +230,7 @@ func init() { }, "args": &Field{ Type: NewNonNull(NewList( - NewNonNull(__InputValue), + NewNonNull(inputValueType), )), }, "onOperation": &Field{ @@ -220,17 +245,16 @@ func init() { }, }) - __Schema = NewObject(ObjectConfig{ + schemaType = NewObject(ObjectConfig{ Name: "__Schema", - Description: `A GraphQL Schema defines the capabilities of a GraphQL -server. It exposes all available types and directives on -the server, as well as the entry points for query and -mutation operations.`, + Description: `A GraphQL Schema defines the capabilities of a GraphQL server. ` + + `It exposes all available types and directives on the server, as well as ` + + `the entry points for query, mutation, and subscription operations.`, Fields: Fields{ "types": &Field{ Description: "A list of all types supported by this server.", Type: NewNonNull(NewList( - NewNonNull(__Type), + NewNonNull(typeType), )), Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { @@ -245,7 +269,7 @@ mutation operations.`, }, "queryType": &Field{ Description: "The type that query operations will be rooted at.", - Type: NewNonNull(__Type), + Type: NewNonNull(typeType), Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { return schema.QueryType(), nil @@ -256,7 +280,7 @@ mutation operations.`, "mutationType": &Field{ Description: `If this server supports mutation, the type that ` + `mutation operations will be rooted at.`, - Type: __Type, + Type: typeType, Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { if schema.MutationType() != nil { @@ -267,17 +291,22 @@ mutation operations.`, }, }, "subscriptionType": &Field{ - Description: `If this server support subscription, the type that ' + - 'subscription operations will be rooted at.`, - Type: __Type, + Description: `If this server supports subscription, the type that ` + + `subscription operations will be rooted at.`, + Type: typeType, Resolve: func(p ResolveParams) (interface{}, error) { + if schema, ok := p.Source.(Schema); ok { + if schema.SubscriptionType() != nil { + return schema.SubscriptionType(), nil + } + } return nil, nil }, }, "directives": &Field{ Description: `A list of all directives supported by this server.`, Type: NewNonNull(NewList( - NewNonNull(__Directive), + NewNonNull(directiveType), )), Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { @@ -289,8 +318,11 @@ mutation operations.`, }, }) - __EnumValue = NewObject(ObjectConfig{ + enumValueType = NewObject(ObjectConfig{ Name: "__EnumValue", + Description: "One possible value for a given Enum. Enum values are unique values, not " + + "a placeholder for a string or numeric value. However an Enum value is " + + "returned in a JSON response as a string.", Fields: Fields{ "name": &Field{ Type: NewNonNull(String), @@ -315,8 +347,8 @@ mutation operations.`, // Again, adding field configs to __Type that have cyclic reference here // because golang don't like them too much during init/compile-time - __Type.AddFieldConfig("fields", &Field{ - Type: NewList(NewNonNull(__Field)), + typeType.AddFieldConfig("fields", &Field{ + Type: NewList(NewNonNull(fieldType)), Args: FieldConfigArgument{ "includeDeprecated": &ArgumentConfig{ Type: Boolean, @@ -354,8 +386,8 @@ mutation operations.`, return nil, nil }, }) - __Type.AddFieldConfig("interfaces", &Field{ - Type: NewList(NewNonNull(__Type)), + typeType.AddFieldConfig("interfaces", &Field{ + Type: NewList(NewNonNull(typeType)), Resolve: func(p ResolveParams) (interface{}, error) { switch ttype := p.Source.(type) { case *Object: @@ -364,8 +396,8 @@ mutation operations.`, return nil, nil }, }) - __Type.AddFieldConfig("possibleTypes", &Field{ - Type: NewList(NewNonNull(__Type)), + typeType.AddFieldConfig("possibleTypes", &Field{ + Type: NewList(NewNonNull(typeType)), Resolve: func(p ResolveParams) (interface{}, error) { switch ttype := p.Source.(type) { case *Interface: @@ -376,8 +408,8 @@ mutation operations.`, return nil, nil }, }) - __Type.AddFieldConfig("enumValues", &Field{ - Type: NewList(NewNonNull(__EnumValue)), + typeType.AddFieldConfig("enumValues", &Field{ + Type: NewList(NewNonNull(enumValueType)), Args: FieldConfigArgument{ "includeDeprecated": &ArgumentConfig{ Type: Boolean, @@ -403,8 +435,8 @@ mutation operations.`, return nil, nil }, }) - __Type.AddFieldConfig("inputFields", &Field{ - Type: NewList(NewNonNull(__InputValue)), + typeType.AddFieldConfig("inputFields", &Field{ + Type: NewList(NewNonNull(inputValueType)), Resolve: func(p ResolveParams) (interface{}, error) { switch ttype := p.Source.(type) { case *InputObject: @@ -417,18 +449,15 @@ mutation operations.`, return nil, nil }, }) - __Type.AddFieldConfig("ofType", &Field{ - Type: __Type, + typeType.AddFieldConfig("ofType", &Field{ + Type: typeType, }) - /** - * Note that these are FieldDefinition and not FieldConfig, - * so the format for args is different. - */ - + // Note that these are FieldDefinition and not FieldConfig, + // so the format for args is different. d SchemaMetaFieldDef = &FieldDefinition{ Name: "__schema", - Type: NewNonNull(__Schema), + Type: NewNonNull(schemaType), Description: "Access the current type schema of this server.", Args: []*Argument{}, Resolve: func(p ResolveParams) (interface{}, error) { @@ -437,7 +466,7 @@ mutation operations.`, } TypeMetaFieldDef = &FieldDefinition{ Name: "__type", - Type: __Type, + Type: typeType, Description: "Request the type information of a single type.", Args: []*Argument{ &Argument{ @@ -466,21 +495,19 @@ mutation operations.`, } -/** - * Produces a GraphQL Value AST given a Golang value. - * - * Optionally, a GraphQL type may be provided, which will be used to - * disambiguate between value primitives. - * - * | JSON Value | GraphQL Value | - * | ------------- | -------------------- | - * | Object | Input Object | - * | Array | List | - * | Boolean | Boolean | - * | String | String / Enum Value | - * | Number | Int / Float | - * - */ +// Produces a GraphQL Value AST given a Golang value. +// +// Optionally, a GraphQL type may be provided, which will be used to +// disambiguate between value primitives. +// +// | JSON Value | GraphQL Value | +// | ------------- | -------------------- | +// | Object | Input Object | +// | Array | List | +// | Boolean | Boolean | +// | String | String / Enum Value | +// | Number | Int / Float | + func astFromValue(value interface{}, ttype Type) ast.Value { if ttype, ok := ttype.(*NonNull); ok { @@ -519,13 +546,12 @@ func astFromValue(value interface{}, ttype Type) ast.Value { return ast.NewListValue(&ast.ListValue{ Values: values, }) - } else { - // Because GraphQL will accept single values as a "list of one" when - // expecting a list, if there's a non-array value and an expected list type, - // create an AST using the list's item type. - val := astFromValue(value, ttype.OfType) - return val } + // Because GraphQL will accept single values as a "list of one" when + // expecting a list, if there's a non-array value and an expected list type, + // create an AST using the list's item type. + val := astFromValue(value, ttype.OfType) + return val } if valueVal.Type().Kind() == reflect.Map { diff --git a/introspection_test.go b/introspection_test.go index eabcfc62..6425b9db 100644 --- a/introspection_test.go +++ b/introspection_test.go @@ -30,7 +30,8 @@ func TestIntrospection_ExecutesAnIntrospectionQuery(t *testing.T) { } expectedDataSubSet := map[string]interface{}{ "__schema": map[string]interface{}{ - "mutationType": nil, + "mutationType": nil, + "subscriptionType": nil, "queryType": map[string]interface{}{ "name": "QueryRoot", }, @@ -93,6 +94,16 @@ func TestIntrospection_ExecutesAnIntrospectionQuery(t *testing.T) { "isDeprecated": false, "deprecationReason": nil, }, + map[string]interface{}{ + "name": "subscriptionType", + "args": []interface{}{}, + "type": map[string]interface{}{ + "kind": "OBJECT", + "name": "__Type", + }, + "isDeprecated": false, + "deprecationReason": nil, + }, map[string]interface{}{ "name": "directives", "args": []interface{}{}, @@ -1257,14 +1268,15 @@ func TestIntrospection_ExposesDescriptionsOnTypesAndFields(t *testing.T) { } } ` + expected := &graphql.Result{ Data: map[string]interface{}{ "schemaType": map[string]interface{}{ "name": "__Schema", - "description": `A GraphQL Schema defines the capabilities of a GraphQL -server. It exposes all available types and directives on -the server, as well as the entry points for query and -mutation operations.`, + "description": `A GraphQL Schema defines the capabilities of a GraphQL ` + + `server. It exposes all available types and directives on ` + + `the server, as well as the entry points for query, mutation, ` + + `and subscription operations.`, "fields": []interface{}{ map[string]interface{}{ "name": "types", @@ -1279,6 +1291,11 @@ mutation operations.`, "description": "If this server supports mutation, the type that " + "mutation operations will be rooted at.", }, + map[string]interface{}{ + "name": "subscriptionType", + "description": "If this server supports subscription, the type that " + + "subscription operations will be rooted at.", + }, map[string]interface{}{ "name": "directives", "description": "A list of all directives supported by this server.", @@ -1327,7 +1344,7 @@ func TestIntrospection_ExposesDescriptionsOnEnums(t *testing.T) { Data: map[string]interface{}{ "typeKindType": map[string]interface{}{ "name": "__TypeKind", - "description": `An enum describing what kind of type a given __Type is`, + "description": "An enum describing what kind of type a given `__Type` is", "enumValues": []interface{}{ map[string]interface{}{ "name": "SCALAR", diff --git a/kitchen-sink.graphql b/kitchen-sink.graphql index 1f98edc9..d075edfd 100644 --- a/kitchen-sink.graphql +++ b/kitchen-sink.graphql @@ -12,6 +12,12 @@ query namedQuery($foo: ComplexFooType, $bar: Bar = DefaultBarValue) { } } } + ... @skip(unless: $foo) { + id + } + ... { + id + } } } @@ -23,6 +29,19 @@ mutation favPost { } } +subscription PostFavSubscription($input: StoryLikeSubscribeInput) { + postFavSubscribe(input: $input) { + post { + favers { + count + } + favSentence { + text + } + } + } +} + fragment frag on Follower { foo(size: $size, bar: $b, obj: {key: "value"}) } diff --git a/language/ast/definitions.go b/language/ast/definitions.go index 19d07ce5..a619e78d 100644 --- a/language/ast/definitions.go +++ b/language/ast/definitions.go @@ -9,11 +9,14 @@ type Definition interface { GetOperation() string GetVariableDefinitions() []*VariableDefinition GetSelectionSet() *SelectionSet + GetKind() string + GetLoc() *Location } // Ensure that all definition types implements Definition interface var _ Definition = (*OperationDefinition)(nil) var _ Definition = (*FragmentDefinition)(nil) +var _ Definition = (*TypeExtensionDefinition)(nil) var _ Definition = (Definition)(nil) // OperationDefinition implements Node, Definition @@ -151,3 +154,41 @@ func (vd *VariableDefinition) GetKind() string { func (vd *VariableDefinition) GetLoc() *Location { return vd.Loc } + +// TypeExtensionDefinition implements Node, Definition +type TypeExtensionDefinition struct { + Kind string + Loc *Location + Definition *ObjectDefinition +} + +func NewTypeExtensionDefinition(def *TypeExtensionDefinition) *TypeExtensionDefinition { + if def == nil { + def = &TypeExtensionDefinition{} + } + return &TypeExtensionDefinition{ + Kind: kinds.TypeExtensionDefinition, + Loc: def.Loc, + Definition: def.Definition, + } +} + +func (def *TypeExtensionDefinition) GetKind() string { + return def.Kind +} + +func (def *TypeExtensionDefinition) GetLoc() *Location { + return def.Loc +} + +func (def *TypeExtensionDefinition) GetVariableDefinitions() []*VariableDefinition { + return []*VariableDefinition{} +} + +func (def *TypeExtensionDefinition) GetSelectionSet() *SelectionSet { + return &SelectionSet{} +} + +func (def *TypeExtensionDefinition) GetOperation() string { + return "" +} diff --git a/language/ast/node.go b/language/ast/node.go index f35eb21d..22879877 100644 --- a/language/ast/node.go +++ b/language/ast/node.go @@ -27,6 +27,7 @@ var _ Node = (*ListValue)(nil) var _ Node = (*ObjectValue)(nil) var _ Node = (*ObjectField)(nil) var _ Node = (*Directive)(nil) +var _ Node = (*Named)(nil) var _ Node = (*List)(nil) var _ Node = (*NonNull)(nil) var _ Node = (*ObjectDefinition)(nil) @@ -39,7 +40,3 @@ var _ Node = (*EnumDefinition)(nil) var _ Node = (*EnumValueDefinition)(nil) var _ Node = (*InputObjectDefinition)(nil) var _ Node = (*TypeExtensionDefinition)(nil) - -// TODO: File issue in `graphql-js` where Named is not -// defined as a Node. This might be a mistake in `graphql-js`? -var _ Node = (*Named)(nil) diff --git a/language/ast/selections.go b/language/ast/selections.go index 1b7e60d2..0dc0ea12 100644 --- a/language/ast/selections.go +++ b/language/ast/selections.go @@ -5,6 +5,7 @@ import ( ) type Selection interface { + GetSelectionSet() *SelectionSet } // Ensure that all definition types implements Selection interface @@ -46,6 +47,10 @@ func (f *Field) GetLoc() *Location { return f.Loc } +func (f *Field) GetSelectionSet() *SelectionSet { + return f.SelectionSet +} + // FragmentSpread implements Node, Selection type FragmentSpread struct { Kind string @@ -74,6 +79,10 @@ func (fs *FragmentSpread) GetLoc() *Location { return fs.Loc } +func (fs *FragmentSpread) GetSelectionSet() *SelectionSet { + return nil +} + // InlineFragment implements Node, Selection type InlineFragment struct { Kind string @@ -104,6 +113,10 @@ func (f *InlineFragment) GetLoc() *Location { return f.Loc } +func (f *InlineFragment) GetSelectionSet() *SelectionSet { + return f.SelectionSet +} + // SelectionSet implements Node type SelectionSet struct { Kind string diff --git a/language/ast/type_definitions.go b/language/ast/type_definitions.go index 7af1d861..95810070 100644 --- a/language/ast/type_definitions.go +++ b/language/ast/type_definitions.go @@ -11,7 +11,6 @@ var _ Definition = (*UnionDefinition)(nil) var _ Definition = (*ScalarDefinition)(nil) var _ Definition = (*EnumDefinition)(nil) var _ Definition = (*InputObjectDefinition)(nil) -var _ Definition = (*TypeExtensionDefinition)(nil) // ObjectDefinition implements Node, Definition type ObjectDefinition struct { @@ -362,41 +361,3 @@ func (def *InputObjectDefinition) GetSelectionSet() *SelectionSet { func (def *InputObjectDefinition) GetOperation() string { return "" } - -// TypeExtensionDefinition implements Node, Definition -type TypeExtensionDefinition struct { - Kind string - Loc *Location - Definition *ObjectDefinition -} - -func NewTypeExtensionDefinition(def *TypeExtensionDefinition) *TypeExtensionDefinition { - if def == nil { - def = &TypeExtensionDefinition{} - } - return &TypeExtensionDefinition{ - Kind: kinds.TypeExtensionDefinition, - Loc: def.Loc, - Definition: def.Definition, - } -} - -func (def *TypeExtensionDefinition) GetKind() string { - return def.Kind -} - -func (def *TypeExtensionDefinition) GetLoc() *Location { - return def.Loc -} - -func (def *TypeExtensionDefinition) GetVariableDefinitions() []*VariableDefinition { - return []*VariableDefinition{} -} - -func (def *TypeExtensionDefinition) GetSelectionSet() *SelectionSet { - return &SelectionSet{} -} - -func (def *TypeExtensionDefinition) GetOperation() string { - return "" -} diff --git a/language/lexer/lexer.go b/language/lexer/lexer.go index 7b55c37c..4c149c34 100644 --- a/language/lexer/lexer.go +++ b/language/lexer/lexer.go @@ -23,7 +23,6 @@ const ( PIPE BRACE_R NAME - VARIABLE INT FLOAT STRING @@ -50,7 +49,6 @@ func init() { TokenKind[PIPE] = PIPE TokenKind[BRACE_R] = BRACE_R TokenKind[NAME] = NAME - TokenKind[VARIABLE] = VARIABLE TokenKind[INT] = INT TokenKind[FLOAT] = FLOAT TokenKind[STRING] = STRING @@ -69,12 +67,13 @@ func init() { tokenDescription[TokenKind[PIPE]] = "|" tokenDescription[TokenKind[BRACE_R]] = "}" tokenDescription[TokenKind[NAME]] = "Name" - tokenDescription[TokenKind[VARIABLE]] = "Variable" tokenDescription[TokenKind[INT]] = "Int" tokenDescription[TokenKind[FLOAT]] = "Float" tokenDescription[TokenKind[STRING]] = "String" } +// Token is a representation of a lexed Token. Value only appears for non-punctuation +// tokens: NAME, INT, FLOAT, and STRING. type Token struct { Kind int Start int @@ -103,6 +102,12 @@ func Lex(s *source.Source) Lexer { } } +func runeStringValueAt(body string, start, end int) string { + // convert body string to runes, to handle unicode + bodyRunes := []rune(body) + return string(bodyRunes[start:end]) +} + // Reads an alphanumeric + underscore name from the source. // [_A-Za-z][_0-9A-Za-z]* func readName(source *source.Source, position int) Token { @@ -111,17 +116,18 @@ func readName(source *source.Source, position int) Token { end := position + 1 for { code := charCodeAt(body, end) - if (end != bodyLength) && (code == 95 || - code >= 48 && code <= 57 || - code >= 65 && code <= 90 || - code >= 97 && code <= 122) { - end += 1 + if (end != bodyLength) && code != 0 && + (code == 95 || // _ + code >= 48 && code <= 57 || // 0-9 + code >= 65 && code <= 90 || // A-Z + code >= 97 && code <= 122) { // a-z + end++ continue } else { break } } - return makeToken(TokenKind[NAME], position, end, body[position:end]) + return makeToken(TokenKind[NAME], position, end, runeStringValueAt(body, position, end)) } // Reads a number token from the source file, either a float @@ -134,14 +140,14 @@ func readNumber(s *source.Source, start int, firstCode rune) (Token, error) { position := start isFloat := false if code == 45 { // - - position += 1 + position++ code = charCodeAt(body, position) } if code == 48 { // 0 - position += 1 + position++ code = charCodeAt(body, position) if code >= 48 && code <= 57 { - description := fmt.Sprintf("Invalid number, unexpected digit after 0: \"%c\".", code) + description := fmt.Sprintf("Invalid number, unexpected digit after 0: %v.", printCharCode(code)) return Token{}, gqlerrors.NewSyntaxError(s, position, description) } } else { @@ -154,7 +160,7 @@ func readNumber(s *source.Source, start int, firstCode rune) (Token, error) { } if code == 46 { // . isFloat = true - position += 1 + position++ code = charCodeAt(body, position) p, err := readDigits(s, position, code) if err != nil { @@ -165,10 +171,10 @@ func readNumber(s *source.Source, start int, firstCode rune) (Token, error) { } if code == 69 || code == 101 { // E e isFloat = true - position += 1 + position++ code = charCodeAt(body, position) if code == 43 || code == 45 { // + - - position += 1 + position++ code = charCodeAt(body, position) } p, err := readDigits(s, position, code) @@ -181,7 +187,7 @@ func readNumber(s *source.Source, start int, firstCode rune) (Token, error) { if isFloat { kind = TokenKind[FLOAT] } - return makeToken(kind, start, position, body[start:position]), nil + return makeToken(kind, start, position, runeStringValueAt(body, start, position)), nil } // Returns the new position in the source after reading digits. @@ -192,7 +198,7 @@ func readDigits(s *source.Source, start int, firstCode rune) (int, error) { if code >= 48 && code <= 57 { // 0 - 9 for { if code >= 48 && code <= 57 { // 0 - 9 - position += 1 + position++ code = charCodeAt(body, position) continue } else { @@ -202,11 +208,7 @@ func readDigits(s *source.Source, start int, firstCode rune) (int, error) { return position, nil } var description string - if code != 0 { - description = fmt.Sprintf("Invalid number, expected digit but got: \"%c\".", code) - } else { - description = fmt.Sprintf("Invalid number, expected digit but got: EOF.") - } + description = fmt.Sprintf("Invalid number, expected digit but got: %v.", printCharCode(code)) return position, gqlerrors.NewSyntaxError(s, position, description) } @@ -218,8 +220,17 @@ func readString(s *source.Source, start int) (Token, error) { var value string for { code = charCodeAt(body, position) - if position < len(body) && code != 34 && code != 10 && code != 13 && code != 0x2028 && code != 0x2029 { - position += 1 + if position < len(body) && + // not LineTerminator + code != 0x000A && code != 0x000D && + // not Quote (") + code != 34 { + + // SourceCharacter + if code < 0x0020 && code != 0x0009 { + return Token{}, gqlerrors.NewSyntaxError(s, position, fmt.Sprintf(`Invalid character within String: %v.`, printCharCode(code))) + } + position++ if code == 92 { // \ value += body[chunkStart : position-1] code = charCodeAt(body, position) @@ -248,7 +259,7 @@ func readString(s *source.Source, start int) (Token, error) { case 116: value += "\t" break - case 117: + case 117: // u charCode := uniCharCode( charCodeAt(body, position+1), charCodeAt(body, position+2), @@ -256,15 +267,18 @@ func readString(s *source.Source, start int) (Token, error) { charCodeAt(body, position+4), ) if charCode < 0 { - return Token{}, gqlerrors.NewSyntaxError(s, position, "Bad character escape sequence.") + return Token{}, gqlerrors.NewSyntaxError(s, position, + fmt.Sprintf("Invalid character escape sequence: "+ + "\\u%v", body[position+1:position+5])) } value += fmt.Sprintf("%c", charCode) position += 4 break default: - return Token{}, gqlerrors.NewSyntaxError(s, position, "Bad character escape sequence.") + return Token{}, gqlerrors.NewSyntaxError(s, position, + fmt.Sprintf(`Invalid character escape sequence: \\%c.`, code)) } - position += 1 + position++ chunkStart = position } continue @@ -272,10 +286,10 @@ func readString(s *source.Source, start int) (Token, error) { break } } - if code != 34 { + if code != 34 { // quote (") return Token{}, gqlerrors.NewSyntaxError(s, position, "Unterminated string.") } - value += body[chunkStart:position] + value += runeStringValueAt(body, chunkStart, position) return makeToken(TokenKind[STRING], start, position+1, value), nil } @@ -299,25 +313,44 @@ func char2hex(a rune) int { return int(a) - 48 } else if a >= 65 && a <= 70 { // A-F return int(a) - 55 - } else if a >= 97 && a <= 102 { // a-f + } else if a >= 97 && a <= 102 { + // a-f return int(a) - 87 - } else { - return -1 } + return -1 } func makeToken(kind int, start int, end int, value string) Token { return Token{Kind: kind, Start: start, End: end, Value: value} } +func printCharCode(code rune) string { + // NaN/undefined represents access beyond the end of the file. + if code < 0 { + return "" + } + // print as ASCII for printable range + if code >= 0x0020 && code < 0x007F { + return fmt.Sprintf(`"%c"`, code) + } + // Otherwise print the escaped form. e.g. `"\\u0007"` + return fmt.Sprintf(`"\\u%04X"`, code) +} + func readToken(s *source.Source, fromPosition int) (Token, error) { body := s.Body bodyLength := len(body) position := positionAfterWhitespace(body, fromPosition) - code := charCodeAt(body, position) if position >= bodyLength { return makeToken(TokenKind[EOF], position, position, ""), nil } + code := charCodeAt(body, position) + + // SourceCharacter + if code < 0x0020 && code != 0x0009 && code != 0x000A && code != 0x000D { + return Token{}, gqlerrors.NewSyntaxError(s, position, fmt.Sprintf(`Invalid character %v`, printCharCode(code))) + } + switch code { // ! case 33: @@ -376,9 +409,8 @@ func readToken(s *source.Source, fromPosition int) (Token, error) { token, err := readNumber(s, position, code) if err != nil { return token, err - } else { - return token, nil } + return token, nil // " case 34: token, err := readString(s, position) @@ -387,7 +419,7 @@ func readToken(s *source.Source, fromPosition int) (Token, error) { } return token, nil } - description := fmt.Sprintf("Unexpected character \"%c\".", code) + description := fmt.Sprintf("Unexpected character %v.", printCharCode(code)) return Token{}, gqlerrors.NewSyntaxError(s, position, description) } @@ -395,9 +427,9 @@ func charCodeAt(body string, position int) rune { r := []rune(body) if len(r) > position { return r[position] - } else { - return 0 } + return -1 + } // Reads from body starting at startPosition until it finds a non-whitespace @@ -409,20 +441,27 @@ func positionAfterWhitespace(body string, startPosition int) int { for { if position < bodyLength { code := charCodeAt(body, position) - if code == 32 || // space - code == 44 || // comma - code == 160 || // '\xa0' - code == 0x2028 || // line separator - code == 0x2029 || // paragraph separator - code > 8 && code < 14 { // whitespace - position += 1 + + // Skip Ignored + if code == 0xFEFF || // BOM + // White Space + code == 0x0009 || // tab + code == 0x0020 || // space + // Line Terminator + code == 0x000A || // new line + code == 0x000D || // carriage return + // Comma + code == 0x002C { + position++ } else if code == 35 { // # - position += 1 + position++ for { code := charCodeAt(body, position) if position < bodyLength && - code != 10 && code != 13 && code != 0x2028 && code != 0x2029 { - position += 1 + code != 0 && + // SourceCharacter but not LineTerminator + (code > 0x001F || code == 0x0009) && code != 0x000A && code != 0x000D { + position++ continue } else { break @@ -442,9 +481,8 @@ func positionAfterWhitespace(body string, startPosition int) int { func GetTokenDesc(token Token) string { if token.Value == "" { return GetTokenKindDesc(token.Kind) - } else { - return fmt.Sprintf("%s \"%s\"", GetTokenKindDesc(token.Kind), token.Value) } + return fmt.Sprintf("%s \"%s\"", GetTokenKindDesc(token.Kind), token.Value) } func GetTokenKindDesc(kind int) string { diff --git a/language/lexer/lexer_test.go b/language/lexer/lexer_test.go index 1db38bcf..13690d1e 100644 --- a/language/lexer/lexer_test.go +++ b/language/lexer/lexer_test.go @@ -16,6 +16,51 @@ func createSource(body string) *source.Source { return source.NewSource(&source.Source{Body: body}) } +func TestDisallowsUncommonControlCharacters(t *testing.T) { + tests := []Test{ + Test{ + Body: "\u0007", + Expected: `Syntax Error GraphQL (1:1) Invalid character "\\u0007" + +1: \u0007 + ^ +`, + }, + } + for _, test := range tests { + _, err := Lex(createSource(test.Body))(0) + if err == nil { + t.Fatalf("unexpected nil error\nexpected:\n%v\n\ngot:\n%v", test.Expected, err) + } + if err.Error() != test.Expected { + t.Fatalf("unexpected error.\nexpected:\n%v\n\ngot:\n%v", test.Expected, err.Error()) + } + } +} + +func TestAcceptsBOMHeader(t *testing.T) { + tests := []Test{ + Test{ + Body: "\uFEFF foo", + Expected: Token{ + Kind: TokenKind[NAME], + Start: 2, + End: 5, + Value: "foo", + }, + }, + } + for _, test := range tests { + token, err := Lex(&source.Source{Body: test.Body})(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(token, test.Expected) { + t.Fatalf("unexpected token, expected: %v, got: %v", test.Expected, token) + } + } +} + func TestSkipsWhiteSpace(t *testing.T) { tests := []Test{ Test{ @@ -150,6 +195,14 @@ func TestLexesStrings(t *testing.T) { func TestLexReportsUsefulStringErrors(t *testing.T) { tests := []Test{ + Test{ + Body: "\"", + Expected: `Syntax Error GraphQL (1:2) Unterminated string. + +1: " + ^ +`, + }, Test{ Body: "\"no end quote", Expected: `Syntax Error GraphQL (1:14) Unterminated string. @@ -159,25 +212,23 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { `, }, Test{ - Body: "\"multi\nline\"", - Expected: `Syntax Error GraphQL (1:7) Unterminated string. + Body: "\"contains unescaped \u0007 control char\"", + Expected: `Syntax Error GraphQL (1:21) Invalid character within String: "\\u0007". -1: "multi - ^ -2: line" +1: "contains unescaped \u0007 control char" + ^ `, }, Test{ - Body: "\"multi\rline\"", - Expected: `Syntax Error GraphQL (1:7) Unterminated string. + Body: "\"null-byte is not \u0000 end of file\"", + Expected: `Syntax Error GraphQL (1:19) Invalid character within String: "\\u0000". -1: "multi - ^ -2: line" +1: "null-byte is not \u0000 end of file" + ^ `, }, Test{ - Body: "\"multi\u2028line\"", + Body: "\"multi\nline\"", Expected: `Syntax Error GraphQL (1:7) Unterminated string. 1: "multi @@ -186,7 +237,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { `, }, Test{ - Body: "\"multi\u2029line\"", + Body: "\"multi\rline\"", Expected: `Syntax Error GraphQL (1:7) Unterminated string. 1: "multi @@ -196,7 +247,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\z esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \\z. 1: "bad \z esc" ^ @@ -204,7 +255,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\x esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \\x. 1: "bad \x esc" ^ @@ -212,7 +263,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\u1 esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \u1 es 1: "bad \u1 esc" ^ @@ -220,7 +271,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\u0XX1 esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \u0XX1 1: "bad \u0XX1 esc" ^ @@ -228,7 +279,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\uXXXX esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \uXXXX 1: "bad \uXXXX esc" ^ @@ -236,7 +287,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\uFXXX esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \uFXXX 1: "bad \uFXXX esc" ^ @@ -244,7 +295,7 @@ func TestLexReportsUsefulStringErrors(t *testing.T) { }, Test{ Body: "\"bad \\uXXXF esc\"", - Expected: `Syntax Error GraphQL (1:7) Bad character escape sequence. + Expected: `Syntax Error GraphQL (1:7) Invalid character escape sequence: \uXXXF 1: "bad \uXXXF esc" ^ @@ -440,7 +491,7 @@ func TestLexReportsUsefulNumbeErrors(t *testing.T) { }, Test{ Body: "1.", - Expected: `Syntax Error GraphQL (1:3) Invalid number, expected digit but got: EOF. + Expected: `Syntax Error GraphQL (1:3) Invalid number, expected digit but got: . 1: 1. ^ @@ -472,7 +523,8 @@ func TestLexReportsUsefulNumbeErrors(t *testing.T) { }, Test{ Body: "1.0e", - Expected: `Syntax Error GraphQL (1:5) Invalid number, expected digit but got: EOF. + + Expected: `Syntax Error GraphQL (1:5) Invalid number, expected digit but got: . 1: 1.0e ^ @@ -649,7 +701,15 @@ func TestLexReportsUsefulUnknownCharacterError(t *testing.T) { }, Test{ Body: "\u203B", - Expected: `Syntax Error GraphQL (1:1) Unexpected character "※". + Expected: `Syntax Error GraphQL (1:1) Unexpected character "\\u203B". + +1: ※ + ^ +`, + }, + Test{ + Body: "\u203b", + Expected: `Syntax Error GraphQL (1:1) Unexpected character "\\u203B". 1: ※ ^ diff --git a/language/location/location.go b/language/location/location.go index f0d47234..ec667caa 100644 --- a/language/location/location.go +++ b/language/location/location.go @@ -18,12 +18,12 @@ func GetLocation(s *source.Source, position int) SourceLocation { } line := 1 column := position + 1 - lineRegexp := regexp.MustCompile("\r\n|[\n\r\u2028\u2029]") + lineRegexp := regexp.MustCompile("\r\n|[\n\r]") matches := lineRegexp.FindAllStringIndex(body, -1) for _, match := range matches { matchIndex := match[0] if matchIndex < position { - line += 1 + line++ l := len(s.Body[match[0]:match[1]]) column = position + 1 - (matchIndex + l) continue diff --git a/language/parser/parser.go b/language/parser/parser.go index 45382418..7b480c3d 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -194,6 +194,13 @@ func parseDocument(parser *Parser) (*ast.Document, error) { /* Implements the parsing rules in the Operations section. */ +/** + * OperationDefinition : + * - SelectionSet + * - OperationType Name? VariableDefinitions? Directives? SelectionSet + * + * OperationType : one of query mutation + */ func parseOperationDefinition(parser *Parser) (*ast.OperationDefinition, error) { start := parser.Token.Start if peek(parser, lexer.TokenKind[lexer.BRACE_L]) { @@ -212,10 +219,20 @@ func parseOperationDefinition(parser *Parser) (*ast.OperationDefinition, error) if err != nil { return nil, err } - operation := operationToken.Value - name, err := parseName(parser) - if err != nil { - return nil, err + operation := "" + switch operationToken.Value { + case "mutation": + fallthrough + case "subscription": + fallthrough + case "query": + operation = operationToken.Value + default: + return nil, unexpected(parser, operationToken) + } + var name *ast.Name + if peek(parser, lexer.TokenKind[lexer.NAME]) { + name, err = parseName(parser) } variableDefinitions, err := parseVariableDefinitions(parser) if err != nil { @@ -239,6 +256,9 @@ func parseOperationDefinition(parser *Parser) (*ast.OperationDefinition, error) }), nil } +/** + * VariableDefinitions : ( VariableDefinition+ ) + */ func parseVariableDefinitions(parser *Parser) ([]*ast.VariableDefinition, error) { variableDefinitions := []*ast.VariableDefinition{} if peek(parser, lexer.TokenKind[lexer.PAREN_L]) { @@ -256,6 +276,9 @@ func parseVariableDefinitions(parser *Parser) ([]*ast.VariableDefinition, error) return variableDefinitions, nil } +/** + * VariableDefinition : Variable : Type DefaultValue? + */ func parseVariableDefinition(parser *Parser) (interface{}, error) { start := parser.Token.Start variable, err := parseVariable(parser) @@ -288,6 +311,9 @@ func parseVariableDefinition(parser *Parser) (interface{}, error) { }), nil } +/** + * Variable : $ Name + */ func parseVariable(parser *Parser) (*ast.Variable, error) { start := parser.Token.Start _, err := expect(parser, lexer.TokenKind[lexer.DOLLAR]) @@ -304,6 +330,9 @@ func parseVariable(parser *Parser) (*ast.Variable, error) { }), nil } +/** + * SelectionSet : { Selection+ } + */ func parseSelectionSet(parser *Parser) (*ast.SelectionSet, error) { start := parser.Token.Start iSelections, err := many(parser, lexer.TokenKind[lexer.BRACE_L], parseSelection, lexer.TokenKind[lexer.BRACE_R]) @@ -324,15 +353,25 @@ func parseSelectionSet(parser *Parser) (*ast.SelectionSet, error) { }), nil } +/** + * Selection : + * - Field + * - FragmentSpread + * - InlineFragment + */ func parseSelection(parser *Parser) (interface{}, error) { if peek(parser, lexer.TokenKind[lexer.SPREAD]) { r, err := parseFragment(parser) return r, err - } else { - return parseField(parser) } + return parseField(parser) } +/** + * Field : Alias? Name Arguments? Directives? SelectionSet? + * + * Alias : Name : + */ func parseField(parser *Parser) (*ast.Field, error) { start := parser.Token.Start nameOrAlias, err := parseName(parser) @@ -381,6 +420,9 @@ func parseField(parser *Parser) (*ast.Field, error) { }), nil } +/** + * Arguments : ( Argument+ ) + */ func parseArguments(parser *Parser) ([]*ast.Argument, error) { arguments := []*ast.Argument{} if peek(parser, lexer.TokenKind[lexer.PAREN_L]) { @@ -398,6 +440,9 @@ func parseArguments(parser *Parser) ([]*ast.Argument, error) { return arguments, nil } +/** + * Argument : Name : Value + */ func parseArgument(parser *Parser) (interface{}, error) { start := parser.Token.Start name, err := parseName(parser) @@ -421,17 +466,21 @@ func parseArgument(parser *Parser) (interface{}, error) { /* Implements the parsing rules in the Fragments section. */ +/** + * Corresponds to both FragmentSpread and InlineFragment in the spec. + * + * FragmentSpread : ... FragmentName Directives? + * + * InlineFragment : ... TypeCondition? Directives? SelectionSet + */ func parseFragment(parser *Parser) (interface{}, error) { start := parser.Token.Start _, err := expect(parser, lexer.TokenKind[lexer.SPREAD]) if err != nil { return nil, err } - if parser.Token.Value == "on" { - if err := advance(parser); err != nil { - return nil, err - } - name, err := parseNamed(parser) + if peek(parser, lexer.TokenKind[lexer.NAME]) && parser.Token.Value != "on" { + name, err := parseFragmentName(parser) if err != nil { return nil, err } @@ -439,32 +488,46 @@ func parseFragment(parser *Parser) (interface{}, error) { if err != nil { return nil, err } - selectionSet, err := parseSelectionSet(parser) + return ast.NewFragmentSpread(&ast.FragmentSpread{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + }), nil + } + var typeCondition *ast.Named + if parser.Token.Value == "on" { + if err := advance(parser); err != nil { + return nil, err + } + name, err := parseNamed(parser) if err != nil { return nil, err } - return ast.NewInlineFragment(&ast.InlineFragment{ - TypeCondition: name, - Directives: directives, - SelectionSet: selectionSet, - Loc: loc(parser, start), - }), nil + typeCondition = name + } - name, err := parseFragmentName(parser) + directives, err := parseDirectives(parser) if err != nil { return nil, err } - directives, err := parseDirectives(parser) + selectionSet, err := parseSelectionSet(parser) if err != nil { return nil, err } - return ast.NewFragmentSpread(&ast.FragmentSpread{ - Name: name, - Directives: directives, - Loc: loc(parser, start), + return ast.NewInlineFragment(&ast.InlineFragment{ + TypeCondition: typeCondition, + Directives: directives, + SelectionSet: selectionSet, + Loc: loc(parser, start), }), nil } +/** + * FragmentDefinition : + * - fragment FragmentName on TypeCondition Directives? SelectionSet + * + * TypeCondition : NamedType + */ func parseFragmentDefinition(parser *Parser) (*ast.FragmentDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "fragment") @@ -500,6 +563,9 @@ func parseFragmentDefinition(parser *Parser) (*ast.FragmentDefinition, error) { }), nil } +/** + * FragmentName : Name but not `on` + */ func parseFragmentName(parser *Parser) (*ast.Name, error) { if parser.Token.Value == "on" { return nil, unexpected(parser, lexer.Token{}) @@ -509,6 +575,21 @@ func parseFragmentName(parser *Parser) (*ast.Name, error) { /* Implements the parsing rules in the Values section. */ +/** + * Value[Const] : + * - [~Const] Variable + * - IntValue + * - FloatValue + * - StringValue + * - BooleanValue + * - EnumValue + * - ListValue[?Const] + * - ObjectValue[?Const] + * + * BooleanValue : one of `true` `false` + * + * EnumValue : Name but not `true`, `false` or `null` + */ func parseValueLiteral(parser *Parser, isConst bool) (ast.Value, error) { token := parser.Token switch token.Kind { @@ -585,6 +666,11 @@ func parseValueValue(parser *Parser) (interface{}, error) { return parseValueLiteral(parser, false) } +/** + * ListValue[Const] : + * - [ ] + * - [ Value[?Const]+ ] + */ func parseList(parser *Parser, isConst bool) (*ast.ListValue, error) { start := parser.Token.Start var item parseFn @@ -607,6 +693,11 @@ func parseList(parser *Parser, isConst bool) (*ast.ListValue, error) { }), nil } +/** + * ObjectValue[Const] : + * - { } + * - { ObjectField[?Const]+ } + */ func parseObject(parser *Parser, isConst bool) (*ast.ObjectValue, error) { start := parser.Token.Start _, err := expect(parser, lexer.TokenKind[lexer.BRACE_L]) @@ -614,18 +705,16 @@ func parseObject(parser *Parser, isConst bool) (*ast.ObjectValue, error) { return nil, err } fields := []*ast.ObjectField{} - fieldNames := map[string]bool{} for { if skp, err := skip(parser, lexer.TokenKind[lexer.BRACE_R]); err != nil { return nil, err } else if skp { break } - field, fieldName, err := parseObjectField(parser, isConst, fieldNames) + field, err := parseObjectField(parser, isConst) if err != nil { return nil, err } - fieldNames[fieldName] = true fields = append(fields, field) } return ast.NewObjectValue(&ast.ObjectValue{ @@ -634,34 +723,35 @@ func parseObject(parser *Parser, isConst bool) (*ast.ObjectValue, error) { }), nil } -func parseObjectField(parser *Parser, isConst bool, fieldNames map[string]bool) (*ast.ObjectField, string, error) { +/** + * ObjectField[Const] : Name : Value[?Const] + */ +func parseObjectField(parser *Parser, isConst bool) (*ast.ObjectField, error) { start := parser.Token.Start name, err := parseName(parser) if err != nil { - return nil, "", err - } - fieldName := name.Value - if _, ok := fieldNames[fieldName]; ok { - descp := fmt.Sprintf("Duplicate input object field %v.", fieldName) - return nil, "", gqlerrors.NewSyntaxError(parser.Source, start, descp) + return nil, err } _, err = expect(parser, lexer.TokenKind[lexer.COLON]) if err != nil { - return nil, "", err + return nil, err } value, err := parseValueLiteral(parser, isConst) if err != nil { - return nil, "", err + return nil, err } return ast.NewObjectField(&ast.ObjectField{ Name: name, Value: value, Loc: loc(parser, start), - }), fieldName, nil + }), nil } /* Implements the parsing rules in the Directives section. */ +/** + * Directives : Directive+ + */ func parseDirectives(parser *Parser) ([]*ast.Directive, error) { directives := []*ast.Directive{} for { @@ -677,6 +767,9 @@ func parseDirectives(parser *Parser) ([]*ast.Directive, error) { return directives, nil } +/** + * Directive : @ Name Arguments? + */ func parseDirective(parser *Parser) (*ast.Directive, error) { start := parser.Token.Start _, err := expect(parser, lexer.TokenKind[lexer.AT]) @@ -700,6 +793,12 @@ func parseDirective(parser *Parser) (*ast.Directive, error) { /* Implements the parsing rules in the Types section. */ +/** + * Type : + * - NamedType + * - ListType + * - NonNullType + */ func parseType(parser *Parser) (ast.Type, error) { start := parser.Token.Start var ttype ast.Type @@ -738,6 +837,9 @@ func parseType(parser *Parser) (ast.Type, error) { return ttype, nil } +/** + * NamedType : Name + */ func parseNamed(parser *Parser) (*ast.Named, error) { start := parser.Token.Start name, err := parseName(parser) @@ -752,6 +854,9 @@ func parseNamed(parser *Parser) (*ast.Named, error) { /* Implements the parsing rules in the Type Definition section. */ +/** + * ObjectTypeDefinition : type Name ImplementsInterfaces? { FieldDefinition+ } + */ func parseObjectTypeDefinition(parser *Parser) (*ast.ObjectDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "type") @@ -784,6 +889,9 @@ func parseObjectTypeDefinition(parser *Parser) (*ast.ObjectDefinition, error) { }), nil } +/** + * ImplementsInterfaces : implements NamedType+ + */ func parseImplementsInterfaces(parser *Parser) ([]*ast.Named, error) { types := []*ast.Named{} if parser.Token.Value == "implements" { @@ -804,6 +912,9 @@ func parseImplementsInterfaces(parser *Parser) ([]*ast.Named, error) { return types, nil } +/** + * FieldDefinition : Name ArgumentsDefinition? : Type + */ func parseFieldDefinition(parser *Parser) (interface{}, error) { start := parser.Token.Start name, err := parseName(parser) @@ -830,6 +941,9 @@ func parseFieldDefinition(parser *Parser) (interface{}, error) { }), nil } +/** + * ArgumentsDefinition : ( InputValueDefinition+ ) + */ func parseArgumentDefs(parser *Parser) ([]*ast.InputValueDefinition, error) { inputValueDefinitions := []*ast.InputValueDefinition{} @@ -848,6 +962,9 @@ func parseArgumentDefs(parser *Parser) ([]*ast.InputValueDefinition, error) { return inputValueDefinitions, err } +/** + * InputValueDefinition : Name : Type DefaultValue? + */ func parseInputValueDef(parser *Parser) (interface{}, error) { start := parser.Token.Start name, err := parseName(parser) @@ -882,6 +999,9 @@ func parseInputValueDef(parser *Parser) (interface{}, error) { }), nil } +/** + * InterfaceTypeDefinition : interface Name { FieldDefinition+ } + */ func parseInterfaceTypeDefinition(parser *Parser) (*ast.InterfaceDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "interface") @@ -909,6 +1029,9 @@ func parseInterfaceTypeDefinition(parser *Parser) (*ast.InterfaceDefinition, err }), nil } +/** + * UnionTypeDefinition : union Name = UnionMembers + */ func parseUnionTypeDefinition(parser *Parser) (*ast.UnionDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "union") @@ -934,6 +1057,11 @@ func parseUnionTypeDefinition(parser *Parser) (*ast.UnionDefinition, error) { }), nil } +/** + * UnionMembers : + * - NamedType + * - UnionMembers | NamedType + */ func parseUnionMembers(parser *Parser) ([]*ast.Named, error) { members := []*ast.Named{} for { @@ -951,6 +1079,9 @@ func parseUnionMembers(parser *Parser) ([]*ast.Named, error) { return members, nil } +/** + * ScalarTypeDefinition : scalar Name + */ func parseScalarTypeDefinition(parser *Parser) (*ast.ScalarDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "scalar") @@ -968,6 +1099,9 @@ func parseScalarTypeDefinition(parser *Parser) (*ast.ScalarDefinition, error) { return def, nil } +/** + * EnumTypeDefinition : enum Name { EnumValueDefinition+ } + */ func parseEnumTypeDefinition(parser *Parser) (*ast.EnumDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "enum") @@ -995,6 +1129,11 @@ func parseEnumTypeDefinition(parser *Parser) (*ast.EnumDefinition, error) { }), nil } +/** + * EnumValueDefinition : EnumValue + * + * EnumValue : Name + */ func parseEnumValueDefinition(parser *Parser) (interface{}, error) { start := parser.Token.Start name, err := parseName(parser) @@ -1007,6 +1146,9 @@ func parseEnumValueDefinition(parser *Parser) (interface{}, error) { }), nil } +/** + * InputObjectTypeDefinition : input Name { InputValueDefinition+ } + */ func parseInputObjectTypeDefinition(parser *Parser) (*ast.InputObjectDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "input") @@ -1034,6 +1176,9 @@ func parseInputObjectTypeDefinition(parser *Parser) (*ast.InputObjectDefinition, }), nil } +/** + * TypeExtensionDefinition : extend ObjectTypeDefinition + */ func parseTypeExtensionDefinition(parser *Parser) (*ast.TypeExtensionDefinition, error) { start := parser.Token.Start _, err := expectKeyWord(parser, "extend") @@ -1095,13 +1240,12 @@ func skip(parser *Parser, Kind int) (bool, error) { if parser.Token.Kind == Kind { err := advance(parser) return true, err - } else { - return false, nil } + return false, nil } // If the next token is of the given kind, return that token after advancing -// the parser. Otherwise, do not change the parser state and return false. +// the parser. Otherwise, do not change the parser state and return error. func expect(parser *Parser, kind int) (lexer.Token, error) { token := parser.Token if token.Kind == kind { diff --git a/language/parser/parser_test.go b/language/parser/parser_test.go index b89d21a7..f8697310 100644 --- a/language/parser/parser_test.go +++ b/language/parser/parser_test.go @@ -10,6 +10,7 @@ import ( "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/location" + "github.com/graphql-go/graphql/language/printer" "github.com/graphql-go/graphql/language/source" ) @@ -137,7 +138,7 @@ fragment MissingOn Type func TestParseProvidesUsefulErrorsWhenUsingSource(t *testing.T) { test := errorMessageTest{ source.NewSource(&source.Source{Body: "query", Name: "MyQuery.graphql"}), - `Syntax Error MyQuery.graphql (1:6) Expected Name, found EOF`, + `Syntax Error MyQuery.graphql (1:6) Expected {, found EOF`, false, } testErrorMessage(t, test) @@ -161,15 +162,6 @@ func TestParsesConstantDefaultValues(t *testing.T) { testErrorMessage(t, test) } -func TestDuplicatedKeysInInputObject(t *testing.T) { - test := errorMessageTest{ - `{ field(arg: { a: 1, a: 2 }) }'`, - `Syntax Error GraphQL (1:22) Duplicate input object field a.`, - false, - } - testErrorMessage(t, test) -} - func TestDoesNotAcceptFragmentsNameOn(t *testing.T) { test := errorMessageTest{ `fragment on on on { on }`, @@ -197,6 +189,83 @@ func TestDoesNotAllowNullAsValue(t *testing.T) { testErrorMessage(t, test) } +func TestParsesMultiByteCharacters(t *testing.T) { + + doc := ` + # This comment has a \u0A0A multi-byte character. + { field(arg: "Has a \u0A0A multi-byte character.") } + ` + astDoc := parse(t, doc) + + expectedASTDoc := ast.NewDocument(&ast.Document{ + Loc: ast.NewLocation(&ast.Location{ + Start: 67, + End: 121, + }), + Definitions: []ast.Node{ + ast.NewOperationDefinition(&ast.OperationDefinition{ + Loc: ast.NewLocation(&ast.Location{ + Start: 67, + End: 119, + }), + Operation: "query", + SelectionSet: ast.NewSelectionSet(&ast.SelectionSet{ + Loc: ast.NewLocation(&ast.Location{ + Start: 67, + End: 119, + }), + Selections: []ast.Selection{ + ast.NewField(&ast.Field{ + Loc: ast.NewLocation(&ast.Location{ + Start: 67, + End: 117, + }), + Name: ast.NewName(&ast.Name{ + Loc: ast.NewLocation(&ast.Location{ + Start: 69, + End: 74, + }), + Value: "field", + }), + Arguments: []*ast.Argument{ + ast.NewArgument(&ast.Argument{ + Loc: ast.NewLocation(&ast.Location{ + Start: 75, + End: 116, + }), + Name: ast.NewName(&ast.Name{ + + Loc: ast.NewLocation(&ast.Location{ + Start: 75, + End: 78, + }), + Value: "arg", + }), + Value: ast.NewStringValue(&ast.StringValue{ + + Loc: ast.NewLocation(&ast.Location{ + Start: 80, + End: 116, + }), + Value: "Has a \u0A0A multi-byte character.", + }), + }), + }, + }), + }, + }), + }), + }, + }) + + astDocQuery := printer.Print(astDoc) + expectedASTDocQuery := printer.Print(expectedASTDoc) + + if !reflect.DeepEqual(astDocQuery, expectedASTDocQuery) { + t.Fatalf("unexpected document, expected: %v, got: %v", astDocQuery, expectedASTDocQuery) + } +} + func TestParsesKitchenSink(t *testing.T) { b, err := ioutil.ReadFile("../../kitchen-sink.graphql") if err != nil { @@ -215,6 +284,7 @@ func TestAllowsNonKeywordsAnywhereNameIsAllowed(t *testing.T) { "fragment", "query", "mutation", + "subscription", "true", "false", } @@ -239,7 +309,56 @@ func TestAllowsNonKeywordsAnywhereNameIsAllowed(t *testing.T) { } } -func TestParsesExperimentalSubscriptionFeature(t *testing.T) { +// +//func TestParsesExperimentalSubscriptionFeature(t *testing.T) { +// source := ` +// subscription Foo { +// subscriptionField +// } +// ` +// _, err := Parse(ParseParams{Source: source}) +// if err != nil { +// t.Fatalf("unexpected error: %v", err) +// } +//} + +func TestParsesAnonymousMutationOperations(t *testing.T) { + source := ` + mutation { + mutationField + } + ` + _, err := Parse(ParseParams{Source: source}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParsesAnonymousSubscriptionOperations(t *testing.T) { + source := ` + subscription { + subscriptionField + } + ` + _, err := Parse(ParseParams{Source: source}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParsesNamedMutationOperations(t *testing.T) { + source := ` + mutation Foo { + mutationField + } + ` + _, err := Parse(ParseParams{Source: source}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParsesNamedSubscriptionOperations(t *testing.T) { source := ` subscription Foo { subscriptionField diff --git a/language/printer/printer.go b/language/printer/printer.go index 8d41b672..4cca90e4 100644 --- a/language/printer/printer.go +++ b/language/printer/printer.go @@ -147,35 +147,38 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ op := node.Operation name := fmt.Sprintf("%v", node.Name) - defs := wrap("(", join(toSliceString(node.VariableDefinitions), ", "), ")") + varDefs := wrap("(", join(toSliceString(node.VariableDefinitions), ", "), ")") directives := join(toSliceString(node.Directives), " ") selectionSet := fmt.Sprintf("%v", node.SelectionSet) + // Anonymous queries with no directives or variable definitions can use + // the query short form. str := "" - if name == "" { + if name == "" && directives == "" && varDefs == "" && op == "query" { str = selectionSet } else { str = join([]string{ op, - join([]string{name, defs}, ""), + join([]string{name, varDefs}, ""), directives, selectionSet, }, " ") } return visitor.ActionUpdate, str case map[string]interface{}: + op := getMapValueString(node, "Operation") name := getMapValueString(node, "Name") - defs := wrap("(", join(toSliceString(getMapValue(node, "VariableDefinitions")), ", "), ")") + varDefs := wrap("(", join(toSliceString(getMapValue(node, "VariableDefinitions")), ", "), ")") directives := join(toSliceString(getMapValue(node, "Directives")), " ") selectionSet := getMapValueString(node, "SelectionSet") str := "" - if name == "" { + if name == "" && directives == "" && varDefs == "" && op == "query" { str = selectionSet } else { str = join([]string{ op, - join([]string{name, defs}, ""), + join([]string{name, varDefs}, ""), directives, selectionSet, }, " ") @@ -275,7 +278,13 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ typeCondition := getMapValueString(node, "TypeCondition") directives := toSliceString(getMapValue(node, "Directives")) selectionSet := getMapValueString(node, "SelectionSet") - return visitor.ActionUpdate, "... on " + typeCondition + " " + wrap("", join(directives, " "), " ") + selectionSet + return visitor.ActionUpdate, + join([]string{ + "...", + wrap("on ", typeCondition, ""), + join(directives, " "), + selectionSet, + }, " ") } return visitor.ActionNoChange, nil }, diff --git a/language/printer/printer_test.go b/language/printer/printer_test.go index 61d3dca1..760ec65a 100644 --- a/language/printer/printer_test.go +++ b/language/printer/printer_test.go @@ -59,6 +59,68 @@ func TestPrinter_PrintsMinimalAST(t *testing.T) { } } +// TestPrinter_ProducesHelpfulErrorMessages +// Skipped, can't figure out how to pass in an invalid astDoc, which is already strongly-typed + +func TestPrinter_CorrectlyPrintsNonQueryOperationsWithoutName(t *testing.T) { + + // Test #1 + queryAstShorthanded := `query { id, name }` + expected := `{ + id + name +} +` + astDoc := parse(t, queryAstShorthanded) + results := printer.Print(astDoc) + + if !reflect.DeepEqual(expected, results) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(results, expected)) + } + + // Test #2 + mutationAst := `mutation { id, name }` + expected = `mutation { + id + name +} +` + astDoc = parse(t, mutationAst) + results = printer.Print(astDoc) + + if !reflect.DeepEqual(expected, results) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(results, expected)) + } + + // Test #3 + queryAstWithArtifacts := `query ($foo: TestType) @testDirective { id, name }` + expected = `query ($foo: TestType) @testDirective { + id + name +} +` + astDoc = parse(t, queryAstWithArtifacts) + results = printer.Print(astDoc) + + if !reflect.DeepEqual(expected, results) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(results, expected)) + } + + // Test #4 + mutationAstWithArtifacts := `mutation ($foo: TestType) @testDirective { id, name }` + expected = `mutation ($foo: TestType) @testDirective { + id + name +} +` + astDoc = parse(t, mutationAstWithArtifacts) + results = printer.Print(astDoc) + + if !reflect.DeepEqual(expected, results) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(results, expected)) + } +} + func TestPrinter_PrintsKitchenSink(t *testing.T) { b, err := ioutil.ReadFile("../../kitchen-sink.graphql") if err != nil { @@ -79,6 +141,12 @@ func TestPrinter_PrintsKitchenSink(t *testing.T) { } } } + ... @skip(unless: $foo) { + id + } + ... { + id + } } } @@ -90,6 +158,19 @@ mutation favPost { } } +subscription PostFavSubscription($input: StoryLikeSubscribeInput) { + postFavSubscribe(input: $input) { + post { + favers { + count + } + favSentence { + text + } + } + } +} + fragment frag on Follower { foo(size: $size, bar: $b, obj: {key: "value"}) } diff --git a/language/typeInfo/type_info.go b/language/typeInfo/type_info.go new file mode 100644 index 00000000..e012ee02 --- /dev/null +++ b/language/typeInfo/type_info.go @@ -0,0 +1,11 @@ +package typeInfo + +import ( + "github.com/graphql-go/graphql/language/ast" +) + +// TypeInfoI defines the interface for TypeInfo Implementation +type TypeInfoI interface { + Enter(node ast.Node) + Leave(node ast.Node) +} diff --git a/language/visitor/visitor.go b/language/visitor/visitor.go index 83edbd9b..62c50149 100644 --- a/language/visitor/visitor.go +++ b/language/visitor/visitor.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/typeInfo" "reflect" ) @@ -17,7 +18,7 @@ const ( type KeyMap map[string][]string // note that the keys are in Capital letters, equivalent to the ast.Node field Names -var QueryDocumentKeys KeyMap = KeyMap{ +var QueryDocumentKeys = KeyMap{ "Name": []string{}, "Document": []string{"Definitions"}, "OperationDefinition": []string{ @@ -380,7 +381,7 @@ Loop: kind = node.GetKind() } - visitFn := GetVisitFn(visitorOpts, isLeaving, kind) + visitFn := GetVisitFn(visitorOpts, kind, isLeaving) if visitFn != nil { p := VisitFuncParams{ Node: nodeIn, @@ -489,7 +490,7 @@ Loop: } } if len(edits) != 0 { - result = edits[0].Value + result = edits[len(edits)-1].Value } return result } @@ -620,10 +621,9 @@ func updateNodeField(value interface{}, fieldName string, fieldValue interface{} if isPtr == true { retVal = val.Addr().Interface() return retVal - } else { - retVal = val.Interface() - return retVal } + retVal = val.Interface() + return retVal } } @@ -709,7 +709,103 @@ func isNilNode(node interface{}) bool { return val.Interface() == nil } -func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitFunc { +// VisitInParallel Creates a new visitor instance which delegates to many visitors to run in +// parallel. Each visitor will be visited for each node before moving on. +// +// If a prior visitor edits a node, no following visitors will see that node. +func VisitInParallel(visitorOptsSlice ...*VisitorOptions) *VisitorOptions { + skipping := map[int]interface{}{} + + return &VisitorOptions{ + Enter: func(p VisitFuncParams) (string, interface{}) { + for i, visitorOpts := range visitorOptsSlice { + if _, ok := skipping[i]; !ok { + switch node := p.Node.(type) { + case ast.Node: + kind := node.GetKind() + fn := GetVisitFn(visitorOpts, kind, false) + if fn != nil { + action, result := fn(p) + if action == ActionSkip { + skipping[i] = node + } else if action == ActionBreak { + skipping[i] = ActionBreak + } else if action == ActionUpdate { + return ActionUpdate, result + } + } + } + } + } + return ActionNoChange, nil + }, + Leave: func(p VisitFuncParams) (string, interface{}) { + for i, visitorOpts := range visitorOptsSlice { + skippedNode, ok := skipping[i] + if !ok { + switch node := p.Node.(type) { + case ast.Node: + kind := node.GetKind() + fn := GetVisitFn(visitorOpts, kind, true) + if fn != nil { + action, result := fn(p) + if action == ActionBreak { + skipping[i] = ActionBreak + } else if action == ActionUpdate { + return ActionUpdate, result + } + } + } + } else if skippedNode == p.Node { + delete(skipping, i) + } + } + return ActionNoChange, nil + }, + } +} + +// VisitWithTypeInfo Creates a new visitor instance which maintains a provided TypeInfo instance +// along with visiting visitor. +func VisitWithTypeInfo(ttypeInfo typeInfo.TypeInfoI, visitorOpts *VisitorOptions) *VisitorOptions { + return &VisitorOptions{ + Enter: func(p VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(ast.Node); ok { + ttypeInfo.Enter(node) + fn := GetVisitFn(visitorOpts, node.GetKind(), false) + if fn != nil { + action, result := fn(p) + if action == ActionUpdate { + ttypeInfo.Leave(node) + if isNode(result) { + if result, ok := result.(ast.Node); ok { + ttypeInfo.Enter(result) + } + } + } + return action, result + } + } + return ActionNoChange, nil + }, + Leave: func(p VisitFuncParams) (string, interface{}) { + action := ActionNoChange + var result interface{} + if node, ok := p.Node.(ast.Node); ok { + fn := GetVisitFn(visitorOpts, node.GetKind(), true) + if fn != nil { + action, result = fn(p) + } + ttypeInfo.Leave(node) + } + return action, result + }, + } +} + +// GetVisitFn Given a visitor instance, if it is leaving or not, and a node kind, return +// the function the visitor runtime should call. +func GetVisitFn(visitorOpts *VisitorOptions, kind string, isLeaving bool) VisitFunc { if visitorOpts == nil { return nil } @@ -722,12 +818,11 @@ func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitF if isLeaving { // { Kind: { leave() {} } } return kindVisitor.Leave - } else { - // { Kind: { enter() {} } } - return kindVisitor.Enter } - } + // { Kind: { enter() {} } } + return kindVisitor.Enter + } if isLeaving { // { enter() {} } specificVisitor := visitorOpts.Leave @@ -739,17 +834,15 @@ func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitF return specificKindVisitor } - } else { - // { leave() {} } - specificVisitor := visitorOpts.Enter - if specificVisitor != nil { - return specificVisitor - } - if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok { - // { enter: { Kind() {} } } - return specificKindVisitor - } } - + // { leave() {} } + specificVisitor := visitorOpts.Enter + if specificVisitor != nil { + return specificVisitor + } + if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok { + // { enter: { Kind() {} } } + return specificKindVisitor + } return nil } diff --git a/language/visitor/visitor_test.go b/language/visitor/visitor_test.go index 412f96c0..b0770125 100644 --- a/language/visitor/visitor_test.go +++ b/language/visitor/visitor_test.go @@ -5,8 +5,12 @@ import ( "reflect" "testing" + "fmt" + "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/kinds" "github.com/graphql-go/graphql/language/parser" + "github.com/graphql-go/graphql/language/printer" "github.com/graphql-go/graphql/language/visitor" "github.com/graphql-go/graphql/testutil" ) @@ -24,6 +28,132 @@ func parse(t *testing.T, query string) *ast.Document { return astDoc } +func TestVisitor_AllowsEditingANodeBothOnEnterAndOnLeave(t *testing.T) { + + query := `{ a, b, c { a, b, c } }` + astDoc := parse(t, query) + + var selectionSet *ast.SelectionSet + + expectedQuery := `{ a, b, c { a, b, c } }` + expectedAST := parse(t, expectedQuery) + + visited := map[string]bool{ + "didEnter": false, + "didLeave": false, + } + + expectedVisited := map[string]bool{ + "didEnter": true, + "didLeave": true, + } + + v := &visitor.VisitorOptions{ + + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.OperationDefinition: visitor.NamedVisitFuncs{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.OperationDefinition); ok { + selectionSet = node.SelectionSet + visited["didEnter"] = true + return visitor.ActionUpdate, ast.NewOperationDefinition(&ast.OperationDefinition{ + Loc: node.Loc, + Operation: node.Operation, + Name: node.Name, + VariableDefinitions: node.VariableDefinitions, + Directives: node.Directives, + SelectionSet: ast.NewSelectionSet(&ast.SelectionSet{ + Selections: []ast.Selection{}, + }), + }) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.OperationDefinition); ok { + visited["didLeave"] = true + return visitor.ActionUpdate, ast.NewOperationDefinition(&ast.OperationDefinition{ + Loc: node.Loc, + Operation: node.Operation, + Name: node.Name, + VariableDefinitions: node.VariableDefinitions, + Directives: node.Directives, + SelectionSet: selectionSet, + }) + } + return visitor.ActionNoChange, nil + }, + }, + }, + } + + editedAst := visitor.Visit(astDoc, v, nil) + if !reflect.DeepEqual(expectedAST, editedAst) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedAST, editedAst)) + } + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(visited, expectedVisited)) + } + +} +func TestVisitor_AllowsEditingTheRootNodeOnEnterAndOnLeave(t *testing.T) { + + query := `{ a, b, c { a, b, c } }` + astDoc := parse(t, query) + + definitions := astDoc.Definitions + + expectedQuery := `{ a, b, c { a, b, c } }` + expectedAST := parse(t, expectedQuery) + + visited := map[string]bool{ + "didEnter": false, + "didLeave": false, + } + + expectedVisited := map[string]bool{ + "didEnter": true, + "didLeave": true, + } + + v := &visitor.VisitorOptions{ + + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.Document: visitor.NamedVisitFuncs{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.Document); ok { + visited["didEnter"] = true + return visitor.ActionUpdate, ast.NewDocument(&ast.Document{ + Loc: node.Loc, + Definitions: []ast.Node{}, + }) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.Document); ok { + visited["didLeave"] = true + return visitor.ActionUpdate, ast.NewDocument(&ast.Document{ + Loc: node.Loc, + Definitions: definitions, + }) + } + return visitor.ActionNoChange, nil + }, + }, + }, + } + + editedAst := visitor.Visit(astDoc, v, nil) + if !reflect.DeepEqual(expectedAST, editedAst) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedAST, editedAst)) + } + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(visited, expectedVisited)) + } +} func TestVisitor_AllowsForEditingOnEnter(t *testing.T) { query := `{ a, b, c { a, b, c } }` @@ -97,10 +227,10 @@ func TestVisitor_VisitsEditedNode(t *testing.T) { s = append(s, addedField) ss := node.SelectionSet ss.Selections = s - return visitor.ActionUpdate, &ast.Field{ + return visitor.ActionUpdate, ast.NewField(&ast.Field{ Kind: "Field", SelectionSet: ss, - } + }) } if reflect.DeepEqual(node, addedField) { didVisitAddedField = true @@ -234,6 +364,65 @@ func TestVisitor_AllowsEarlyExitWhileVisiting(t *testing.T) { } } +func TestVisitor_AllowsEarlyExitWhileLeaving(t *testing.T) { + + visited := []interface{}{} + + query := `{ a, b { x }, c }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil}, + []interface{}{"enter", "OperationDefinition", nil}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "b"}, + []interface{}{"leave", "Name", "b"}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "x"}, + []interface{}{"leave", "Name", "x"}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value}) + if node.Value == "x" { + return visitor.ActionBreak, nil + } + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, v, nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + func TestVisitor_AllowsANamedFunctionsVisitorAPI(t *testing.T) { query := `{ a, b { x }, c }` @@ -355,6 +544,7 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { []interface{}{"leave", "Name", "Name", "Directive"}, []interface{}{"leave", "Directive", 0, nil}, []interface{}{"enter", "SelectionSet", "SelectionSet", "InlineFragment"}, + []interface{}{"enter", "Field", 0, nil}, []interface{}{"enter", "Name", "Name", "Field"}, []interface{}{"leave", "Name", "Name", "Field"}, @@ -409,6 +599,35 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { []interface{}{"leave", "Field", 0, nil}, []interface{}{"leave", "SelectionSet", "SelectionSet", "InlineFragment"}, []interface{}{"leave", "InlineFragment", 1, nil}, + []interface{}{"enter", "InlineFragment", 2, nil}, + []interface{}{"enter", "Directive", 0, nil}, + []interface{}{"enter", "Name", "Name", "Directive"}, + []interface{}{"leave", "Name", "Name", "Directive"}, + []interface{}{"enter", "Argument", 0, nil}, + []interface{}{"enter", "Name", "Name", "Argument"}, + []interface{}{"leave", "Name", "Name", "Argument"}, + []interface{}{"enter", "Variable", "Value", "Argument"}, + []interface{}{"enter", "Name", "Name", "Variable"}, + []interface{}{"leave", "Name", "Name", "Variable"}, + + []interface{}{"leave", "Variable", "Value", "Argument"}, + []interface{}{"leave", "Argument", 0, nil}, + []interface{}{"leave", "Directive", 0, nil}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "InlineFragment"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "InlineFragment"}, + []interface{}{"leave", "InlineFragment", 2, nil}, + []interface{}{"enter", "InlineFragment", 3, nil}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "InlineFragment"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "InlineFragment"}, + []interface{}{"leave", "InlineFragment", 3, nil}, []interface{}{"leave", "SelectionSet", "SelectionSet", "Field"}, []interface{}{"leave", "Field", 0, nil}, []interface{}{"leave", "SelectionSet", "SelectionSet", "OperationDefinition"}, @@ -445,7 +664,64 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { []interface{}{"leave", "Field", 0, nil}, []interface{}{"leave", "SelectionSet", "SelectionSet", "OperationDefinition"}, []interface{}{"leave", "OperationDefinition", 1, nil}, - []interface{}{"enter", "FragmentDefinition", 2, nil}, + []interface{}{"enter", "OperationDefinition", 2, nil}, + []interface{}{"enter", "Name", "Name", "OperationDefinition"}, + []interface{}{"leave", "Name", "Name", "OperationDefinition"}, + []interface{}{"enter", "VariableDefinition", 0, nil}, + []interface{}{"enter", "Variable", "Variable", "VariableDefinition"}, + []interface{}{"enter", "Name", "Name", "Variable"}, + []interface{}{"leave", "Name", "Name", "Variable"}, + + []interface{}{"leave", "Variable", "Variable", "VariableDefinition"}, + []interface{}{"enter", "Named", "Type", "VariableDefinition"}, + []interface{}{"enter", "Name", "Name", "Named"}, + []interface{}{"leave", "Name", "Name", "Named"}, + []interface{}{"leave", "Named", "Type", "VariableDefinition"}, + []interface{}{"leave", "VariableDefinition", 0, nil}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "OperationDefinition"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"enter", "Argument", 0, nil}, + []interface{}{"enter", "Name", "Name", "Argument"}, + []interface{}{"leave", "Name", "Name", "Argument"}, + []interface{}{"enter", "Variable", "Value", "Argument"}, + []interface{}{"enter", "Name", "Name", "Variable"}, + []interface{}{"leave", "Name", "Name", "Variable"}, + []interface{}{"leave", "Variable", "Value", "Argument"}, + []interface{}{"leave", "Argument", 0, nil}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"enter", "Field", 1, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"enter", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"enter", "Field", 0, nil}, + []interface{}{"enter", "Name", "Name", "Field"}, + []interface{}{"leave", "Name", "Name", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"leave", "Field", 1, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "Field"}, + []interface{}{"leave", "Field", 0, nil}, + []interface{}{"leave", "SelectionSet", "SelectionSet", "OperationDefinition"}, + []interface{}{"leave", "OperationDefinition", 2, nil}, + []interface{}{"enter", "FragmentDefinition", 3, nil}, []interface{}{"enter", "Name", "Name", "FragmentDefinition"}, []interface{}{"leave", "Name", "Name", "FragmentDefinition"}, []interface{}{"enter", "Named", "TypeCondition", "FragmentDefinition"}, @@ -459,6 +735,7 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { []interface{}{"enter", "Argument", 0, nil}, []interface{}{"enter", "Name", "Name", "Argument"}, []interface{}{"leave", "Name", "Name", "Argument"}, + []interface{}{"enter", "Variable", "Value", "Argument"}, []interface{}{"enter", "Name", "Name", "Variable"}, []interface{}{"leave", "Name", "Name", "Variable"}, @@ -486,8 +763,8 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { []interface{}{"leave", "Argument", 2, nil}, []interface{}{"leave", "Field", 0, nil}, []interface{}{"leave", "SelectionSet", "SelectionSet", "FragmentDefinition"}, - []interface{}{"leave", "FragmentDefinition", 2, nil}, - []interface{}{"enter", "OperationDefinition", 3, nil}, + []interface{}{"leave", "FragmentDefinition", 3, nil}, + []interface{}{"enter", "OperationDefinition", 4, nil}, []interface{}{"enter", "SelectionSet", "SelectionSet", "OperationDefinition"}, []interface{}{"enter", "Field", 0, nil}, []interface{}{"enter", "Name", "Name", "Field"}, @@ -510,7 +787,7 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { []interface{}{"leave", "Name", "Name", "Field"}, []interface{}{"leave", "Field", 1, nil}, []interface{}{"leave", "SelectionSet", "SelectionSet", "OperationDefinition"}, - []interface{}{"leave", "OperationDefinition", 3, nil}, + []interface{}{"leave", "OperationDefinition", 4, nil}, []interface{}{"leave", "Document", nil, nil}, } @@ -545,3 +822,926 @@ func TestVisitor_VisitsKitchenSink(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) } } + +func TestVisitor_VisitInParallel_AllowsSkippingASubTree(t *testing.T) { + + // Note: nearly identical to the above test of the same test but + // using visitInParallel. + + query := `{ a, b { x }, c }` + astDoc := parse(t, query) + + visited := []interface{}{} + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil}, + []interface{}{"enter", "OperationDefinition", nil}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "c"}, + []interface{}{"leave", "Name", "c"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"leave", "SelectionSet", nil}, + []interface{}{"leave", "OperationDefinition", nil}, + []interface{}{"leave", "Document", nil}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value}) + case *ast.Field: + visited = append(visited, []interface{}{"enter", node.Kind, nil}) + if node.Name != nil && node.Name.Value == "b" { + return visitor.ActionSkip, nil + } + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsSkippingDifferentSubTrees(t *testing.T) { + + query := `{ a { x }, b { y} }` + astDoc := parse(t, query) + + visited := []interface{}{} + expectedVisited := []interface{}{ + []interface{}{"no-a", "enter", "Document", nil}, + []interface{}{"no-b", "enter", "Document", nil}, + []interface{}{"no-a", "enter", "OperationDefinition", nil}, + []interface{}{"no-b", "enter", "OperationDefinition", nil}, + []interface{}{"no-a", "enter", "SelectionSet", nil}, + []interface{}{"no-b", "enter", "SelectionSet", nil}, + []interface{}{"no-a", "enter", "Field", nil}, + []interface{}{"no-b", "enter", "Field", nil}, + []interface{}{"no-b", "enter", "Name", "a"}, + []interface{}{"no-b", "leave", "Name", "a"}, + []interface{}{"no-b", "enter", "SelectionSet", nil}, + []interface{}{"no-b", "enter", "Field", nil}, + []interface{}{"no-b", "enter", "Name", "x"}, + []interface{}{"no-b", "leave", "Name", "x"}, + []interface{}{"no-b", "leave", "Field", nil}, + []interface{}{"no-b", "leave", "SelectionSet", nil}, + []interface{}{"no-b", "leave", "Field", nil}, + []interface{}{"no-a", "enter", "Field", nil}, + []interface{}{"no-b", "enter", "Field", nil}, + []interface{}{"no-a", "enter", "Name", "b"}, + []interface{}{"no-a", "leave", "Name", "b"}, + []interface{}{"no-a", "enter", "SelectionSet", nil}, + []interface{}{"no-a", "enter", "Field", nil}, + []interface{}{"no-a", "enter", "Name", "y"}, + []interface{}{"no-a", "leave", "Name", "y"}, + []interface{}{"no-a", "leave", "Field", nil}, + []interface{}{"no-a", "leave", "SelectionSet", nil}, + []interface{}{"no-a", "leave", "Field", nil}, + []interface{}{"no-a", "leave", "SelectionSet", nil}, + []interface{}{"no-b", "leave", "SelectionSet", nil}, + []interface{}{"no-a", "leave", "OperationDefinition", nil}, + []interface{}{"no-b", "leave", "OperationDefinition", nil}, + []interface{}{"no-a", "leave", "Document", nil}, + []interface{}{"no-b", "leave", "Document", nil}, + } + + v := []*visitor.VisitorOptions{ + &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"no-a", "enter", node.Kind, node.Value}) + case *ast.Field: + visited = append(visited, []interface{}{"no-a", "enter", node.Kind, nil}) + if node.Name != nil && node.Name.Value == "a" { + return visitor.ActionSkip, nil + } + case ast.Node: + visited = append(visited, []interface{}{"no-a", "enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"no-a", "enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"no-a", "leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"no-a", "leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"no-a", "leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + }, + &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"no-b", "enter", node.Kind, node.Value}) + case *ast.Field: + visited = append(visited, []interface{}{"no-b", "enter", node.Kind, nil}) + if node.Name != nil && node.Name.Value == "b" { + return visitor.ActionSkip, nil + } + case ast.Node: + visited = append(visited, []interface{}{"no-b", "enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"no-b", "enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"no-b", "leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"no-b", "leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"no-b", "leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v...), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsEarlyExitWhileVisiting(t *testing.T) { + + // Note: nearly identical to the above test of the same test but + // using visitInParallel. + + visited := []interface{}{} + + query := `{ a, b { x }, c }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil}, + []interface{}{"enter", "OperationDefinition", nil}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "b"}, + []interface{}{"leave", "Name", "b"}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "x"}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value}) + if node.Value == "x" { + return visitor.ActionBreak, nil + } + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsEarlyExitFromDifferentPoints(t *testing.T) { + + visited := []interface{}{} + + query := `{ a { y }, b { x } }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"break-a", "enter", "Document", nil}, + []interface{}{"break-b", "enter", "Document", nil}, + []interface{}{"break-a", "enter", "OperationDefinition", nil}, + []interface{}{"break-b", "enter", "OperationDefinition", nil}, + []interface{}{"break-a", "enter", "SelectionSet", nil}, + []interface{}{"break-b", "enter", "SelectionSet", nil}, + []interface{}{"break-a", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-a", "enter", "Name", "a"}, + []interface{}{"break-b", "enter", "Name", "a"}, + []interface{}{"break-b", "leave", "Name", "a"}, + []interface{}{"break-b", "enter", "SelectionSet", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Name", "y"}, + []interface{}{"break-b", "leave", "Name", "y"}, + []interface{}{"break-b", "leave", "Field", nil}, + []interface{}{"break-b", "leave", "SelectionSet", nil}, + []interface{}{"break-b", "leave", "Field", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Name", "b"}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"break-a", "enter", node.Kind, node.Value}) + if node != nil && node.Value == "a" { + return visitor.ActionBreak, nil + } + case ast.Node: + visited = append(visited, []interface{}{"break-a", "enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-a", "enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"break-a", "leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"break-a", "leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-a", "leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + v2 := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"break-b", "enter", node.Kind, node.Value}) + if node != nil && node.Value == "b" { + return visitor.ActionBreak, nil + } + case ast.Node: + visited = append(visited, []interface{}{"break-b", "enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-b", "enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"break-b", "leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"break-b", "leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-b", "leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v, v2), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsEarlyExitWhileLeaving(t *testing.T) { + + visited := []interface{}{} + + query := `{ a, b { x }, c }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil}, + []interface{}{"enter", "OperationDefinition", nil}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "b"}, + []interface{}{"leave", "Name", "b"}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "x"}, + []interface{}{"leave", "Name", "x"}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value}) + if node.Value == "x" { + return visitor.ActionBreak, nil + } + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsEarlyExitFromLeavingDifferentPoints(t *testing.T) { + + visited := []interface{}{} + + query := `{ a { y }, b { x } }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"break-a", "enter", "Document", nil}, + []interface{}{"break-b", "enter", "Document", nil}, + []interface{}{"break-a", "enter", "OperationDefinition", nil}, + []interface{}{"break-b", "enter", "OperationDefinition", nil}, + []interface{}{"break-a", "enter", "SelectionSet", nil}, + []interface{}{"break-b", "enter", "SelectionSet", nil}, + []interface{}{"break-a", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-a", "enter", "Name", "a"}, + []interface{}{"break-b", "enter", "Name", "a"}, + []interface{}{"break-a", "leave", "Name", "a"}, + []interface{}{"break-b", "leave", "Name", "a"}, + []interface{}{"break-a", "enter", "SelectionSet", nil}, + []interface{}{"break-b", "enter", "SelectionSet", nil}, + []interface{}{"break-a", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-a", "enter", "Name", "y"}, + []interface{}{"break-b", "enter", "Name", "y"}, + []interface{}{"break-a", "leave", "Name", "y"}, + []interface{}{"break-b", "leave", "Name", "y"}, + []interface{}{"break-a", "leave", "Field", nil}, + []interface{}{"break-b", "leave", "Field", nil}, + []interface{}{"break-a", "leave", "SelectionSet", nil}, + []interface{}{"break-b", "leave", "SelectionSet", nil}, + []interface{}{"break-a", "leave", "Field", nil}, + []interface{}{"break-b", "leave", "Field", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Name", "b"}, + []interface{}{"break-b", "leave", "Name", "b"}, + []interface{}{"break-b", "enter", "SelectionSet", nil}, + []interface{}{"break-b", "enter", "Field", nil}, + []interface{}{"break-b", "enter", "Name", "x"}, + []interface{}{"break-b", "leave", "Name", "x"}, + []interface{}{"break-b", "leave", "Field", nil}, + []interface{}{"break-b", "leave", "SelectionSet", nil}, + []interface{}{"break-b", "leave", "Field", nil}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"break-a", "enter", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"break-a", "enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-a", "enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Field: + visited = append(visited, []interface{}{"break-a", "leave", node.GetKind(), nil}) + if node.Name != nil && node.Name.Value == "a" { + return visitor.ActionBreak, nil + } + case *ast.Name: + visited = append(visited, []interface{}{"break-a", "leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"break-a", "leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-a", "leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + v2 := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"break-b", "enter", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"break-b", "enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-b", "enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Field: + visited = append(visited, []interface{}{"break-b", "leave", node.GetKind(), nil}) + if node.Name != nil && node.Name.Value == "b" { + return visitor.ActionBreak, nil + } + case *ast.Name: + visited = append(visited, []interface{}{"break-b", "leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"break-b", "leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"break-b", "leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v, v2), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsForEditingOnEnter(t *testing.T) { + + visited := []interface{}{} + + query := `{ a, b, c { a, b, c } }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil}, + []interface{}{"enter", "OperationDefinition", nil}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "c"}, + []interface{}{"leave", "Name", "c"}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "c"}, + []interface{}{"leave", "Name", "c"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"leave", "SelectionSet", nil}, + []interface{}{"leave", "Field", nil}, + []interface{}{"leave", "SelectionSet", nil}, + []interface{}{"leave", "OperationDefinition", nil}, + []interface{}{"leave", "Document", nil}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Field: + if node != nil && node.Name != nil && node.Name.Value == "b" { + return visitor.ActionUpdate, nil + } + } + return visitor.ActionNoChange, nil + }, + } + + v2 := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitInParallel(v, v2), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } +} + +func TestVisitor_VisitInParallel_AllowsForEditingOnLeave(t *testing.T) { + + visited := []interface{}{} + + query := `{ a, b, c { a, b, c } }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil}, + []interface{}{"enter", "OperationDefinition", nil}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "b"}, + []interface{}{"leave", "Name", "b"}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "c"}, + []interface{}{"leave", "Name", "c"}, + []interface{}{"enter", "SelectionSet", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "a"}, + []interface{}{"leave", "Name", "a"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "b"}, + []interface{}{"leave", "Name", "b"}, + []interface{}{"enter", "Field", nil}, + []interface{}{"enter", "Name", "c"}, + []interface{}{"leave", "Name", "c"}, + []interface{}{"leave", "Field", nil}, + []interface{}{"leave", "SelectionSet", nil}, + []interface{}{"leave", "Field", nil}, + []interface{}{"leave", "SelectionSet", nil}, + []interface{}{"leave", "OperationDefinition", nil}, + []interface{}{"leave", "Document", nil}, + } + + v := &visitor.VisitorOptions{ + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Field: + if node != nil && node.Name != nil && node.Name.Value == "b" { + return visitor.ActionUpdate, nil + } + } + return visitor.ActionNoChange, nil + }, + } + + v2 := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"enter", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value}) + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil}) + default: + visited = append(visited, []interface{}{"leave", nil, nil}) + } + return visitor.ActionNoChange, nil + }, + } + + editedAST := visitor.Visit(astDoc, visitor.VisitInParallel(v, v2), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } + + expectedEditedAST := parse(t, `{ a, c { a, c } }`) + if !reflect.DeepEqual(editedAST, expectedEditedAST) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(editedAST, expectedEditedAST)) + } +} + +func TestVisitor_VisitWithTypeInfo_MaintainsTypeInfoDuringVisit(t *testing.T) { + + visited := []interface{}{} + + typeInfo := graphql.NewTypeInfo(&graphql.TypeInfoConfig{ + Schema: testutil.TestSchema, + }) + + query := `{ human(id: 4) { name, pets { name }, unknown } }` + astDoc := parse(t, query) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil, nil, nil, nil}, + []interface{}{"enter", "OperationDefinition", nil, nil, "QueryRoot", nil}, + []interface{}{"enter", "SelectionSet", nil, "QueryRoot", "QueryRoot", nil}, + []interface{}{"enter", "Field", nil, "QueryRoot", "Human", nil}, + []interface{}{"enter", "Name", "human", "QueryRoot", "Human", nil}, + []interface{}{"leave", "Name", "human", "QueryRoot", "Human", nil}, + []interface{}{"enter", "Argument", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"enter", "Name", "id", "QueryRoot", "Human", "ID"}, + []interface{}{"leave", "Name", "id", "QueryRoot", "Human", "ID"}, + []interface{}{"enter", "IntValue", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"leave", "IntValue", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"leave", "Argument", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"enter", "SelectionSet", nil, "Human", "Human", nil}, + []interface{}{"enter", "Field", nil, "Human", "String", nil}, + []interface{}{"enter", "Name", "name", "Human", "String", nil}, + []interface{}{"leave", "Name", "name", "Human", "String", nil}, + []interface{}{"leave", "Field", nil, "Human", "String", nil}, + []interface{}{"enter", "Field", nil, "Human", "[Pet]", nil}, + []interface{}{"enter", "Name", "pets", "Human", "[Pet]", nil}, + []interface{}{"leave", "Name", "pets", "Human", "[Pet]", nil}, + []interface{}{"enter", "SelectionSet", nil, "Pet", "[Pet]", nil}, + []interface{}{"enter", "Field", nil, "Pet", "String", nil}, + []interface{}{"enter", "Name", "name", "Pet", "String", nil}, + []interface{}{"leave", "Name", "name", "Pet", "String", nil}, + []interface{}{"leave", "Field", nil, "Pet", "String", nil}, + []interface{}{"leave", "SelectionSet", nil, "Pet", "[Pet]", nil}, + []interface{}{"leave", "Field", nil, "Human", "[Pet]", nil}, + []interface{}{"enter", "Field", nil, "Human", nil, nil}, + []interface{}{"enter", "Name", "unknown", "Human", nil, nil}, + []interface{}{"leave", "Name", "unknown", "Human", nil, nil}, + []interface{}{"leave", "Field", nil, "Human", nil, nil}, + []interface{}{"leave", "SelectionSet", nil, "Human", "Human", nil}, + []interface{}{"leave", "Field", nil, "QueryRoot", "Human", nil}, + []interface{}{"leave", "SelectionSet", nil, "QueryRoot", "QueryRoot", nil}, + []interface{}{"leave", "OperationDefinition", nil, nil, "QueryRoot", nil}, + []interface{}{"leave", "Document", nil, nil, nil, nil}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + var parentType interface{} + var ttype interface{} + var inputType interface{} + + if typeInfo.ParentType() != nil { + parentType = fmt.Sprintf("%v", typeInfo.ParentType()) + } + if typeInfo.Type() != nil { + ttype = fmt.Sprintf("%v", typeInfo.Type()) + } + if typeInfo.InputType() != nil { + inputType = fmt.Sprintf("%v", typeInfo.InputType()) + } + + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value, parentType, ttype, inputType}) + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil, parentType, ttype, inputType}) + default: + visited = append(visited, []interface{}{"enter", nil, nil, parentType, ttype, inputType}) + } + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + var parentType interface{} + var ttype interface{} + var inputType interface{} + + if typeInfo.ParentType() != nil { + parentType = fmt.Sprintf("%v", typeInfo.ParentType()) + } + if typeInfo.Type() != nil { + ttype = fmt.Sprintf("%v", typeInfo.Type()) + } + if typeInfo.InputType() != nil { + inputType = fmt.Sprintf("%v", typeInfo.InputType()) + } + + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value, parentType, ttype, inputType}) + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil, parentType, ttype, inputType}) + default: + visited = append(visited, []interface{}{"leave", nil, nil, parentType, ttype, inputType}) + } + return visitor.ActionNoChange, nil + }, + } + + _ = visitor.Visit(astDoc, visitor.VisitWithTypeInfo(typeInfo, v), nil) + + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } + +} + +func TestVisitor_VisitWithTypeInfo_MaintainsTypeInfoDuringEdit(t *testing.T) { + + visited := []interface{}{} + + typeInfo := graphql.NewTypeInfo(&graphql.TypeInfoConfig{ + Schema: testutil.TestSchema, + }) + + astDoc := parse(t, `{ human(id: 4) { name, pets }, alien }`) + + expectedVisited := []interface{}{ + []interface{}{"enter", "Document", nil, nil, nil, nil}, + []interface{}{"enter", "OperationDefinition", nil, nil, "QueryRoot", nil}, + []interface{}{"enter", "SelectionSet", nil, "QueryRoot", "QueryRoot", nil}, + []interface{}{"enter", "Field", nil, "QueryRoot", "Human", nil}, + []interface{}{"enter", "Name", "human", "QueryRoot", "Human", nil}, + []interface{}{"leave", "Name", "human", "QueryRoot", "Human", nil}, + []interface{}{"enter", "Argument", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"enter", "Name", "id", "QueryRoot", "Human", "ID"}, + []interface{}{"leave", "Name", "id", "QueryRoot", "Human", "ID"}, + []interface{}{"enter", "IntValue", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"leave", "IntValue", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"leave", "Argument", nil, "QueryRoot", "Human", "ID"}, + []interface{}{"enter", "SelectionSet", nil, "Human", "Human", nil}, + []interface{}{"enter", "Field", nil, "Human", "String", nil}, + []interface{}{"enter", "Name", "name", "Human", "String", nil}, + []interface{}{"leave", "Name", "name", "Human", "String", nil}, + []interface{}{"leave", "Field", nil, "Human", "String", nil}, + []interface{}{"enter", "Field", nil, "Human", "[Pet]", nil}, + []interface{}{"enter", "Name", "pets", "Human", "[Pet]", nil}, + []interface{}{"leave", "Name", "pets", "Human", "[Pet]", nil}, + []interface{}{"enter", "SelectionSet", nil, "Pet", "[Pet]", nil}, + []interface{}{"enter", "Field", nil, "Pet", "String!", nil}, + []interface{}{"enter", "Name", "__typename", "Pet", "String!", nil}, + []interface{}{"leave", "Name", "__typename", "Pet", "String!", nil}, + []interface{}{"leave", "Field", nil, "Pet", "String!", nil}, + []interface{}{"leave", "SelectionSet", nil, "Pet", "[Pet]", nil}, + []interface{}{"leave", "Field", nil, "Human", "[Pet]", nil}, + []interface{}{"leave", "SelectionSet", nil, "Human", "Human", nil}, + []interface{}{"leave", "Field", nil, "QueryRoot", "Human", nil}, + []interface{}{"enter", "Field", nil, "QueryRoot", "Alien", nil}, + []interface{}{"enter", "Name", "alien", "QueryRoot", "Alien", nil}, + []interface{}{"leave", "Name", "alien", "QueryRoot", "Alien", nil}, + []interface{}{"enter", "SelectionSet", nil, "Alien", "Alien", nil}, + []interface{}{"enter", "Field", nil, "Alien", "String!", nil}, + []interface{}{"enter", "Name", "__typename", "Alien", "String!", nil}, + []interface{}{"leave", "Name", "__typename", "Alien", "String!", nil}, + []interface{}{"leave", "Field", nil, "Alien", "String!", nil}, + []interface{}{"leave", "SelectionSet", nil, "Alien", "Alien", nil}, + []interface{}{"leave", "Field", nil, "QueryRoot", "Alien", nil}, + []interface{}{"leave", "SelectionSet", nil, "QueryRoot", "QueryRoot", nil}, + []interface{}{"leave", "OperationDefinition", nil, nil, "QueryRoot", nil}, + []interface{}{"leave", "Document", nil, nil, nil, nil}, + } + + v := &visitor.VisitorOptions{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + var parentType interface{} + var ttype interface{} + var inputType interface{} + + if typeInfo.ParentType() != nil { + parentType = fmt.Sprintf("%v", typeInfo.ParentType()) + } + if typeInfo.Type() != nil { + ttype = fmt.Sprintf("%v", typeInfo.Type()) + } + if typeInfo.InputType() != nil { + inputType = fmt.Sprintf("%v", typeInfo.InputType()) + } + + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"enter", node.Kind, node.Value, parentType, ttype, inputType}) + case *ast.Field: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil, parentType, ttype, inputType}) + + // Make a query valid by adding missing selection sets. + if node.SelectionSet == nil && graphql.IsCompositeType(graphql.GetNamed(typeInfo.Type())) { + return visitor.ActionUpdate, ast.NewField(&ast.Field{ + Alias: node.Alias, + Name: node.Name, + Arguments: node.Arguments, + Directives: node.Directives, + SelectionSet: ast.NewSelectionSet(&ast.SelectionSet{ + Selections: []ast.Selection{ + ast.NewField(&ast.Field{ + Name: ast.NewName(&ast.Name{ + Value: "__typename", + }), + }), + }, + }), + }) + } + case ast.Node: + visited = append(visited, []interface{}{"enter", node.GetKind(), nil, parentType, ttype, inputType}) + default: + visited = append(visited, []interface{}{"enter", nil, nil, parentType, ttype, inputType}) + } + + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + var parentType interface{} + var ttype interface{} + var inputType interface{} + + if typeInfo.ParentType() != nil { + parentType = fmt.Sprintf("%v", typeInfo.ParentType()) + } + if typeInfo.Type() != nil { + ttype = fmt.Sprintf("%v", typeInfo.Type()) + } + if typeInfo.InputType() != nil { + inputType = fmt.Sprintf("%v", typeInfo.InputType()) + } + + switch node := p.Node.(type) { + case *ast.Name: + visited = append(visited, []interface{}{"leave", node.Kind, node.Value, parentType, ttype, inputType}) + case ast.Node: + visited = append(visited, []interface{}{"leave", node.GetKind(), nil, parentType, ttype, inputType}) + default: + visited = append(visited, []interface{}{"leave", nil, nil, parentType, ttype, inputType}) + } + return visitor.ActionNoChange, nil + }, + } + + editedAST := visitor.Visit(astDoc, visitor.VisitWithTypeInfo(typeInfo, v), nil) + + editedASTQuery := printer.Print(editedAST.(ast.Node)) + expectedEditedASTQuery := printer.Print(parse(t, `{ human(id: 4) { name, pets { __typename } }, alien { __typename } }`)) + + if !reflect.DeepEqual(editedASTQuery, expectedEditedASTQuery) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(editedASTQuery, expectedEditedASTQuery)) + } + if !reflect.DeepEqual(visited, expectedVisited) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedVisited, visited)) + } + +} diff --git a/lists_test.go b/lists_test.go index 4bc47cd7..4c30b500 100644 --- a/lists_test.go +++ b/lists_test.go @@ -761,3 +761,22 @@ func TestLists_NonNullListOfNonNullArrayOfFunc_ContainsNulls(t *testing.T) { } checkList(t, ttype, data, expected) } + +func TestLists_UserErrorExpectIterableButDidNotGetOne(t *testing.T) { + ttype := graphql.NewList(graphql.Int) + data := "Not an iterable" + expected := &graphql.Result{ + Data: map[string]interface{}{ + "nest": map[string]interface{}{ + "test": nil, + }, + }, + Errors: []gqlerrors.FormattedError{ + gqlerrors.FormattedError{ + Message: "User Error: expected iterable, but did not find one for field DataType.test.", + Locations: []location.SourceLocation{}, + }, + }, + } + checkList(t, ttype, data, expected) +} diff --git a/located.go b/located.go index e7a4cdc0..6ed8ec83 100644 --- a/located.go +++ b/located.go @@ -1,17 +1,21 @@ package graphql import ( + "errors" "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" ) func NewLocatedError(err interface{}, nodes []ast.Node) *gqlerrors.Error { + var origError error message := "An unknown error occurred." if err, ok := err.(error); ok { message = err.Error() + origError = err } if err, ok := err.(string); ok { message = err + origError = errors.New(err) } stack := message return gqlerrors.NewError( @@ -20,6 +24,7 @@ func NewLocatedError(err interface{}, nodes []ast.Node) *gqlerrors.Error { stack, nil, []int{}, + origError, ) } diff --git a/rules.go b/rules.go index 80f10754..1b518846 100644 --- a/rules.go +++ b/rules.go @@ -11,9 +11,7 @@ import ( "strings" ) -/** - * SpecifiedRules set includes all validation rules defined by the GraphQL spec. - */ +// SpecifiedRules set includes all validation rules defined by the GraphQL spec. var SpecifiedRules = []ValidationRuleFn{ ArgumentsOfCorrectTypeRule, DefaultValuesOfCorrectTypeRule, @@ -34,57 +32,69 @@ var SpecifiedRules = []ValidationRuleFn{ ScalarLeafsRule, UniqueArgumentNamesRule, UniqueFragmentNamesRule, + UniqueInputFieldNamesRule, UniqueOperationNamesRule, + UniqueVariableNamesRule, VariablesAreInputTypesRule, VariablesInAllowedPositionRule, } type ValidationRuleInstance struct { - VisitorOpts *visitor.VisitorOptions - VisitSpreadFragments bool + VisitorOpts *visitor.VisitorOptions } type ValidationRuleFn func(context *ValidationContext) *ValidationRuleInstance -func newValidationRuleError(message string, nodes []ast.Node) (string, error) { - return visitor.ActionNoChange, gqlerrors.NewError( +func newValidationError(message string, nodes []ast.Node) *gqlerrors.Error { + return gqlerrors.NewError( message, nodes, "", nil, []int{}, + nil, // TODO: this is interim, until we port "better-error-messages-for-inputs" ) } -/** - * ArgumentsOfCorrectTypeRule - * Argument values of correct type - * - * A GraphQL document is only valid if all field argument literal values are - * of the type expected by their position. - */ +func reportError(context *ValidationContext, message string, nodes []ast.Node) (string, interface{}) { + context.ReportError(newValidationError(message, nodes)) + return visitor.ActionNoChange, nil +} + +// ArgumentsOfCorrectTypeRule Argument values of correct type +// +// A GraphQL document is only valid if all field argument literal values are +// of the type expected by their position. func ArgumentsOfCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.Argument: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - var action = visitor.ActionNoChange - var result interface{} if argAST, ok := p.Node.(*ast.Argument); ok { value := argAST.Value argDef := context.Argument() - if argDef != nil && !isValidLiteralValue(argDef.Type, value) { - argNameValue := "" - if argAST.Name != nil { - argNameValue = argAST.Name.Value + if argDef != nil { + isValid, messages := isValidLiteralValue(argDef.Type, value) + if !isValid { + argNameValue := "" + if argAST.Name != nil { + argNameValue = argAST.Name.Value + } + + messagesStr := "" + if len(messages) > 0 { + messagesStr = "\n" + strings.Join(messages, "\n") + } + reportError( + context, + fmt.Sprintf(`Argument "%v" has invalid value %v.%v`, + argNameValue, printer.Print(value), messagesStr), + []ast.Node{value}, + ) } - return newValidationRuleError( - fmt.Sprintf(`Argument "%v" expected type "%v" but got: %v.`, - argNameValue, argDef.Type, printer.Print(value)), - []ast.Node{value}, - ) + } } - return action, result + return visitor.ActionSkip, nil }, }, }, @@ -94,20 +104,15 @@ func ArgumentsOfCorrectTypeRule(context *ValidationContext) *ValidationRuleInsta } } -/** - * DefaultValuesOfCorrectTypeRule - * Variable default values of correct type - * - * A GraphQL document is only valid if all variable default values are of the - * type expected by their definition. - */ +// DefaultValuesOfCorrectTypeRule Variable default values of correct type +// +// A GraphQL document is only valid if all variable default values are of the +// type expected by their definition. func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.VariableDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - var action = visitor.ActionNoChange - var result interface{} if varDefAST, ok := p.Node.(*ast.VariableDefinition); ok { name := "" if varDefAST.Variable != nil && varDefAST.Variable.Name != nil { @@ -117,21 +122,38 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI ttype := context.InputType() if ttype, ok := ttype.(*NonNull); ok && defaultValue != nil { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Variable "$%v" of type "%v" is required and will not use the default value. Perhaps you meant to use type "%v".`, name, ttype, ttype.OfType), []ast.Node{defaultValue}, ) } - if ttype != nil && defaultValue != nil && !isValidLiteralValue(ttype, defaultValue) { - return newValidationRuleError( - fmt.Sprintf(`Variable "$%v" of type "%v" has invalid default value: %v.`, - name, ttype, printer.Print(defaultValue)), + isValid, messages := isValidLiteralValue(ttype, defaultValue) + if ttype != nil && defaultValue != nil && !isValid { + messagesStr := "" + if len(messages) > 0 { + messagesStr = "\n" + strings.Join(messages, "\n") + } + reportError( + context, + fmt.Sprintf(`Variable "$%v" has invalid default value: %v.%v`, + name, printer.Print(defaultValue), messagesStr), []ast.Node{defaultValue}, ) } } - return action, result + return visitor.ActionSkip, nil + }, + }, + kinds.SelectionSet: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.FragmentDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil }, }, }, @@ -141,13 +163,36 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI } } -/** - * FieldsOnCorrectTypeRule - * Fields on correct type - * - * A GraphQL document is only valid if all fields selected are defined by the - * parent type, or are an allowed meta field such as __typenamme - */ +func UndefinedFieldMessage(fieldName string, ttypeName string, suggestedTypes []string) string { + + quoteStrings := func(slice []string) []string { + quoted := []string{} + for _, s := range slice { + quoted = append(quoted, fmt.Sprintf(`"%v"`, s)) + } + return quoted + } + + // construct helpful (but long) message + message := fmt.Sprintf(`Cannot query field "%v" on type "%v".`, fieldName, ttypeName) + suggestions := strings.Join(quoteStrings(suggestedTypes), ", ") + const MaxLength = 5 + if len(suggestedTypes) > 0 { + if len(suggestedTypes) > MaxLength { + suggestions = strings.Join(quoteStrings(suggestedTypes[0:MaxLength]), ", ") + + fmt.Sprintf(`, and %v other types`, len(suggestedTypes)-MaxLength) + } + message = message + fmt.Sprintf(` However, this field exists on %v.`, suggestions) + message = message + ` Perhaps you meant to use an inline fragment?` + } + + return message +} + +// FieldsOnCorrectTypeRule Fields on correct type +// +// A GraphQL document is only valid if all fields selected are defined by the +// parent type, or are an allowed meta field such as __typenamme func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ @@ -161,13 +206,37 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance if ttype != nil { fieldDef := context.FieldDef() if fieldDef == nil { + // This isn't valid. Let's find suggestions, if any. + suggestedTypes := []string{} + nodeName := "" if node.Name != nil { nodeName = node.Name.Value } - return newValidationRuleError( - fmt.Sprintf(`Cannot query field "%v" on "%v".`, - nodeName, ttype.Name()), + + if ttype, ok := ttype.(Abstract); ok { + siblingInterfaces := getSiblingInterfacesIncludingField(ttype, nodeName) + implementations := getImplementationsIncludingField(ttype, nodeName) + suggestedMaps := map[string]bool{} + for _, s := range siblingInterfaces { + if _, ok := suggestedMaps[s]; !ok { + suggestedMaps[s] = true + suggestedTypes = append(suggestedTypes, s) + } + } + for _, s := range implementations { + if _, ok := suggestedMaps[s]; !ok { + suggestedMaps[s] = true + suggestedTypes = append(suggestedTypes, s) + } + } + } + + message := UndefinedFieldMessage(nodeName, ttype.Name(), suggestedTypes) + + reportError( + context, + message, []ast.Node{node}, ) } @@ -183,14 +252,90 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance } } -/** - * FragmentsOnCompositeTypesRule - * Fragments on composite type - * - * Fragments use a type condition to determine if they apply, since fragments - * can only be spread into a composite type (object, interface, or union), the - * type condition must also be a composite type. - */ +// Return implementations of `type` that include `fieldName` as a valid field. +func getImplementationsIncludingField(ttype Abstract, fieldName string) []string { + + result := []string{} + for _, t := range ttype.PossibleTypes() { + fields := t.Fields() + if _, ok := fields[fieldName]; ok { + result = append(result, fmt.Sprintf(`%v`, t.Name())) + } + } + + sort.Strings(result) + return result +} + +// Go through all of the implementations of type, and find other interaces +// that they implement. If those interfaces include `field` as a valid field, +// return them, sorted by how often the implementations include the other +// interface. +func getSiblingInterfacesIncludingField(ttype Abstract, fieldName string) []string { + implementingObjects := ttype.PossibleTypes() + + result := []string{} + suggestedInterfaceSlice := []*suggestedInterface{} + + // stores a map of interface name => index in suggestedInterfaceSlice + suggestedInterfaceMap := map[string]int{} + + for _, t := range implementingObjects { + for _, i := range t.Interfaces() { + if i == nil { + continue + } + fields := i.Fields() + if _, ok := fields[fieldName]; !ok { + continue + } + index, ok := suggestedInterfaceMap[i.Name()] + if !ok { + suggestedInterfaceSlice = append(suggestedInterfaceSlice, &suggestedInterface{ + name: i.Name(), + count: 0, + }) + index = len(suggestedInterfaceSlice) - 1 + } + if index < len(suggestedInterfaceSlice) { + s := suggestedInterfaceSlice[index] + if s.name == i.Name() { + s.count = s.count + 1 + } + } + } + } + sort.Sort(suggestedInterfaceSortedSlice(suggestedInterfaceSlice)) + + for _, s := range suggestedInterfaceSlice { + result = append(result, fmt.Sprintf(`%v`, s.name)) + } + return result + +} + +type suggestedInterface struct { + name string + count int +} + +type suggestedInterfaceSortedSlice []*suggestedInterface + +func (s suggestedInterfaceSortedSlice) Len() int { + return len(s) +} +func (s suggestedInterfaceSortedSlice) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} +func (s suggestedInterfaceSortedSlice) Less(i, j int) bool { + return s[i].count < s[j].count +} + +// FragmentsOnCompositeTypesRule Fragments on composite type +// +// Fragments use a type condition to determine if they apply, since fragments +// can only be spread into a composite type (object, interface, or union), the +// type condition must also be a composite type. func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ @@ -198,8 +343,9 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.InlineFragment); ok { ttype := context.Type() - if ttype != nil && !IsCompositeType(ttype) { - return newValidationRuleError( + if node.TypeCondition != nil && ttype != nil && !IsCompositeType(ttype) { + reportError( + context, fmt.Sprintf(`Fragment cannot condition on non composite type "%v".`, ttype), []ast.Node{node.TypeCondition}, ) @@ -217,7 +363,8 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn if node.Name != nil { nodeName = node.Name.Value } - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Fragment "%v" cannot condition on non composite type "%v".`, nodeName, printer.Print(node.TypeCondition)), []ast.Node{node.TypeCondition}, ) @@ -233,13 +380,10 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn } } -/** - * KnownArgumentNamesRule - * Known argument names - * - * A GraphQL field is only valid if all supplied arguments are defined by - * that field. - */ +// KnownArgumentNamesRule Known argument names +// +// A GraphQL field is only valid if all supplied arguments are defined by +// that field. func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ @@ -255,7 +399,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if argumentOf == nil { return action, result } - if argumentOf.GetKind() == "Field" { + if argumentOf.GetKind() == kinds.Field { fieldDef := context.FieldDef() if fieldDef == nil { return action, result @@ -276,12 +420,13 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if parentType != nil { parentTypeName = parentType.Name() } - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Unknown argument "%v" on field "%v" of type "%v".`, nodeName, fieldDef.Name, parentTypeName), []ast.Node{node}, ) } - } else if argumentOf.GetKind() == "Directive" { + } else if argumentOf.GetKind() == kinds.Directive { directive := context.Directive() if directive == nil { return action, result @@ -297,7 +442,8 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance } } if directiveArgDef == nil { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Unknown argument "%v" on directive "@%v".`, nodeName, directive.Name), []ast.Node{node}, ) @@ -315,12 +461,10 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance } } -/** - * Known directives - * - * A GraphQL document is only valid if all `@directives` are known by the - * schema and legally positioned. - */ +// KnownDirectivesRule Known directives +// +// A GraphQL document is only valid if all `@directives` are known by the +// schema and legally positioned. func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ @@ -342,7 +486,8 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } } if directiveDef == nil { - return newValidationRuleError( + return reportError( + context, fmt.Sprintf(`Unknown directive "%v".`, nodeName), []ast.Node{node}, ) @@ -357,13 +502,15 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } if appliedTo.GetKind() == kinds.OperationDefinition && directiveDef.OnOperation == false { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "operation"), []ast.Node{node}, ) } if appliedTo.GetKind() == kinds.Field && directiveDef.OnField == false { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "field"), []ast.Node{node}, ) @@ -371,7 +518,8 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { if (appliedTo.GetKind() == kinds.FragmentSpread || appliedTo.GetKind() == kinds.InlineFragment || appliedTo.GetKind() == kinds.FragmentDefinition) && directiveDef.OnFragment == false { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "fragment"), []ast.Node{node}, ) @@ -388,13 +536,10 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } } -/** - * KnownFragmentNamesRule - * Known fragment names - * - * A GraphQL document is only valid if all `...Fragment` fragment spreads refer - * to fragments defined in the same document. - */ +// KnownFragmentNamesRule Known fragment names +// +// A GraphQL document is only valid if all `...Fragment` fragment spreads refer +// to fragments defined in the same document. func KnownFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ @@ -411,7 +556,8 @@ func KnownFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance fragment := context.Fragment(fragmentName) if fragment == nil { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Unknown fragment "%v".`, fragmentName), []ast.Node{node.Name}, ) @@ -427,16 +573,33 @@ func KnownFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance } } -/** - * KnownTypeNamesRule - * Known type names - * - * A GraphQL document is only valid if referenced types (specifically - * variable definitions and fragment conditions) are defined by the type schema. - */ +// KnownTypeNamesRule Known type names +// +// A GraphQL document is only valid if referenced types (specifically +// variable definitions and fragment conditions) are defined by the type schema. func KnownTypeNamesRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.ObjectDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.InterfaceDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.UnionDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.InputObjectDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, kinds.Named: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.Named); ok { @@ -447,7 +610,8 @@ func KnownTypeNamesRule(context *ValidationContext) *ValidationRuleInstance { } ttype := context.Schema().Type(typeNameValue) if ttype == nil { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Unknown type "%v".`, typeNameValue), []ast.Node{node}, ) @@ -463,13 +627,10 @@ func KnownTypeNamesRule(context *ValidationContext) *ValidationRuleInstance { } } -/** - * LoneAnonymousOperationRule - * Lone anonymous operation - * - * A GraphQL document is only valid if when it contains an anonymous operation - * (the query short-hand) that it contains only that one operation definition. - */ +// LoneAnonymousOperationRule Lone anonymous operation +// +// A GraphQL document is only valid if when it contains an anonymous operation +// (the query short-hand) that it contains only that one operation definition. func LoneAnonymousOperationRule(context *ValidationContext) *ValidationRuleInstance { var operationCount = 0 visitorOpts := &visitor.VisitorOptions{ @@ -491,7 +652,8 @@ func LoneAnonymousOperationRule(context *ValidationContext) *ValidationRuleInsta Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.OperationDefinition); ok { if node.Name == nil && operationCount > 1 { - return newValidationRuleError( + reportError( + context, `This anonymous operation must be the only defined operation.`, []ast.Node{node}, ) @@ -528,97 +690,111 @@ func (set *nodeSet) Add(node ast.Node) bool { return true } -/** - * NoFragmentCyclesRule - */ +func CycleErrorMessage(fragName string, spreadNames []string) string { + via := "" + if len(spreadNames) > 0 { + via = " via " + strings.Join(spreadNames, ", ") + } + return fmt.Sprintf(`Cannot spread fragment "%v" within itself%v.`, fragName, via) +} + +// NoFragmentCyclesRule No fragment cycles func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { - // Gather all the fragment spreads ASTs for each fragment definition. - // Importantly this does not include inline fragments. - definitions := context.Document().Definitions - spreadsInFragment := map[string][]*ast.FragmentSpread{} - for _, node := range definitions { - if node.GetKind() == kinds.FragmentDefinition { - if node, ok := node.(*ast.FragmentDefinition); ok && node != nil { - nodeName := "" - if node.Name != nil { - nodeName = node.Name.Value + + // Tracks already visited fragments to maintain O(N) and to ensure that cycles + // are not redundantly reported. + visitedFrags := map[string]bool{} + + // Array of AST nodes used to produce meaningful errors + spreadPath := []*ast.FragmentSpread{} + + // Position in the spread path + spreadPathIndexByName := map[string]int{} + + // This does a straight-forward DFS to find cycles. + // It does not terminate when a cycle was found but continues to explore + // the graph to find all possible cycles. + var detectCycleRecursive func(fragment *ast.FragmentDefinition) + detectCycleRecursive = func(fragment *ast.FragmentDefinition) { + + fragmentName := "" + if fragment.Name != nil { + fragmentName = fragment.Name.Value + } + visitedFrags[fragmentName] = true + + spreadNodes := context.FragmentSpreads(fragment) + if len(spreadNodes) == 0 { + return + } + + spreadPathIndexByName[fragmentName] = len(spreadPath) + + for _, spreadNode := range spreadNodes { + + spreadName := "" + if spreadNode.Name != nil { + spreadName = spreadNode.Name.Value + } + cycleIndex, ok := spreadPathIndexByName[spreadName] + if !ok { + spreadPath = append(spreadPath, spreadNode) + if visited, ok := visitedFrags[spreadName]; !ok || !visited { + spreadFragment := context.Fragment(spreadName) + if spreadFragment != nil { + detectCycleRecursive(spreadFragment) + } + } + spreadPath = spreadPath[:len(spreadPath)-1] + } else { + cyclePath := spreadPath[cycleIndex:] + + spreadNames := []string{} + for _, s := range cyclePath { + name := "" + if s.Name != nil { + name = s.Name.Value + } + spreadNames = append(spreadNames, name) } - spreadsInFragment[nodeName] = gatherSpreads(node) + + nodes := []ast.Node{} + for _, c := range cyclePath { + nodes = append(nodes, c) + } + nodes = append(nodes, spreadNode) + + reportError( + context, + CycleErrorMessage(spreadName, spreadNames), + nodes, + ) } + } + delete(spreadPathIndexByName, fragmentName) + } - // Tracks spreads known to lead to cycles to ensure that cycles are not - // redundantly reported. - knownToLeadToCycle := newNodeSet() visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.OperationDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, kinds.FragmentDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.FragmentDefinition); ok && node != nil { - errors := []error{} - spreadPath := []*ast.FragmentSpread{} - initialName := "" + nodeName := "" if node.Name != nil { - initialName = node.Name.Value - } - var detectCycleRecursive func(fragmentName string) - detectCycleRecursive = func(fragmentName string) { - spreadNodes, _ := spreadsInFragment[fragmentName] - for _, spreadNode := range spreadNodes { - if knownToLeadToCycle.Has(spreadNode) { - continue - } - spreadNodeName := "" - if spreadNode.Name != nil { - spreadNodeName = spreadNode.Name.Value - } - if spreadNodeName == initialName { - cyclePath := []ast.Node{} - for _, path := range spreadPath { - cyclePath = append(cyclePath, path) - } - cyclePath = append(cyclePath, spreadNode) - for _, spread := range cyclePath { - knownToLeadToCycle.Add(spread) - } - via := "" - spreadNames := []string{} - for _, s := range spreadPath { - if s.Name != nil { - spreadNames = append(spreadNames, s.Name.Value) - } - } - if len(spreadNames) > 0 { - via = " via " + strings.Join(spreadNames, ", ") - } - _, err := newValidationRuleError( - fmt.Sprintf(`Cannot spread fragment "%v" within itself%v.`, initialName, via), - cyclePath, - ) - errors = append(errors, err) - continue - } - spreadPathHasCurrentNode := false - for _, spread := range spreadPath { - if spread == spreadNode { - spreadPathHasCurrentNode = true - } - } - if spreadPathHasCurrentNode { - continue - } - spreadPath = append(spreadPath, spreadNode) - detectCycleRecursive(spreadNodeName) - _, spreadPath = spreadPath[len(spreadPath)-1], spreadPath[:len(spreadPath)-1] - } + nodeName = node.Name.Value } - detectCycleRecursive(initialName) - if len(errors) > 0 { - return visitor.ActionNoChange, errors + if _, ok := visitedFrags[nodeName]; !ok { + detectCycleRecursive(node) } } - return visitor.ActionNoChange, nil + return visitor.ActionSkip, nil }, }, }, @@ -628,83 +804,66 @@ func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { } } -/** - * NoUndefinedVariables - * No undefined variables - * - * A GraphQL operation is only valid if all variables encountered, both directly - * and via fragment spreads, are defined by that operation. - */ +func UndefinedVarMessage(varName string, opName string) string { + if opName != "" { + return fmt.Sprintf(`Variable "$%v" is not defined by operation "%v".`, varName, opName) + } + return fmt.Sprintf(`Variable "$%v" is not defined.`, varName) +} + +// NoUndefinedVariablesRule No undefined variables +// +// A GraphQL operation is only valid if all variables encountered, both directly +// and via fragment spreads, are defined by that operation. func NoUndefinedVariablesRule(context *ValidationContext) *ValidationRuleInstance { - var operation *ast.OperationDefinition - var visitedFragmentNames = map[string]bool{} - var definedVariableNames = map[string]bool{} + var variableNameDefined = map[string]bool{} + visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.OperationDefinition); ok && node != nil { - operation = node - visitedFragmentNames = map[string]bool{} - definedVariableNames = map[string]bool{} - } - return visitor.ActionNoChange, nil - }, - }, - kinds.VariableDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { - variableName := "" - if node.Variable != nil && node.Variable.Name != nil { - variableName = node.Variable.Name.Value - } - definedVariableNames[variableName] = true - } + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + variableNameDefined = map[string]bool{} return visitor.ActionNoChange, nil }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variable, ok := p.Node.(*ast.Variable); ok && variable != nil { - variableName := "" - if variable.Name != nil { - variableName = variable.Name.Value - } - if val, _ := definedVariableNames[variableName]; !val { - withinFragment := false - for _, node := range p.Ancestors { - if node.GetKind() == kinds.FragmentDefinition { - withinFragment = true - break - } + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if operation, ok := p.Node.(*ast.OperationDefinition); ok && operation != nil { + usages := context.RecursiveVariableUsages(operation) + + for _, usage := range usages { + if usage == nil { + continue + } + if usage.Node == nil { + continue } - if withinFragment == true && operation != nil && operation.Name != nil { - return newValidationRuleError( - fmt.Sprintf(`Variable "$%v" is not defined by operation "%v".`, variableName, operation.Name.Value), - []ast.Node{variable, operation}, + varName := "" + if usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + opName := "" + if operation.Name != nil { + opName = operation.Name.Value + } + if res, ok := variableNameDefined[varName]; !ok || !res { + reportError( + context, + UndefinedVarMessage(varName, opName), + []ast.Node{usage.Node, operation}, ) } - return newValidationRuleError( - fmt.Sprintf(`Variable "$%v" is not defined.`, variableName), - []ast.Node{variable}, - ) } } return visitor.ActionNoChange, nil }, }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ + kinds.VariableDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.FragmentSpread); ok && node != nil { - // Only visit fragments of a particular name once per operation - fragmentName := "" - if node.Name != nil { - fragmentName = node.Name.Value - } - if val, ok := visitedFragmentNames[fragmentName]; ok && val == true { - return visitor.ActionSkip, nil + if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { + variableName := "" + if node.Variable != nil && node.Variable.Name != nil { + variableName = node.Variable.Name.Value } - visitedFragmentNames[fragmentName] = true + variableNameDefined[variableName] = true } return visitor.ActionNoChange, nil }, @@ -712,84 +871,51 @@ func NoUndefinedVariablesRule(context *ValidationContext) *ValidationRuleInstanc }, } return &ValidationRuleInstance{ - VisitSpreadFragments: true, - VisitorOpts: visitorOpts, + VisitorOpts: visitorOpts, } } -/** - * NoUnusedFragmentsRule - * No unused fragments - * - * A GraphQL document is only valid if all fragment definitions are spread - * within operations, or spread within other fragments spread within operations. - */ +// NoUnusedFragmentsRule No unused fragments +// +// A GraphQL document is only valid if all fragment definitions are spread +// within operations, or spread within other fragments spread within operations. func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { var fragmentDefs = []*ast.FragmentDefinition{} - var spreadsWithinOperation = []map[string]bool{} - var fragAdjacencies = map[string]map[string]bool{} - var spreadNames = map[string]bool{} + var operationDefs = []*ast.OperationDefinition{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.OperationDefinition); ok && node != nil { - spreadNames = map[string]bool{} - spreadsWithinOperation = append(spreadsWithinOperation, spreadNames) + operationDefs = append(operationDefs, node) } - return visitor.ActionNoChange, nil + return visitor.ActionSkip, nil }, }, kinds.FragmentDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if def, ok := p.Node.(*ast.FragmentDefinition); ok && def != nil { - defName := "" - if def.Name != nil { - defName = def.Name.Value - } - - fragmentDefs = append(fragmentDefs, def) - spreadNames = map[string]bool{} - fragAdjacencies[defName] = spreadNames - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if spread, ok := p.Node.(*ast.FragmentSpread); ok && spread != nil { - spreadName := "" - if spread.Name != nil { - spreadName = spread.Name.Value - } - spreadNames[spreadName] = true + if node, ok := p.Node.(*ast.FragmentDefinition); ok && node != nil { + fragmentDefs = append(fragmentDefs, node) } - return visitor.ActionNoChange, nil + return visitor.ActionSkip, nil }, }, kinds.Document: visitor.NamedVisitFuncs{ Leave: func(p visitor.VisitFuncParams) (string, interface{}) { - - fragmentNameUsed := map[string]interface{}{} - - var reduceSpreadFragments func(spreads map[string]bool) - reduceSpreadFragments = func(spreads map[string]bool) { - for fragName, _ := range spreads { - if isFragNameUsed, _ := fragmentNameUsed[fragName]; isFragNameUsed != true { - fragmentNameUsed[fragName] = true - - if adjacencies, ok := fragAdjacencies[fragName]; ok { - reduceSpreadFragments(adjacencies) - } + fragmentNameUsed := map[string]bool{} + for _, operation := range operationDefs { + fragments := context.RecursivelyReferencedFragments(operation) + for _, fragment := range fragments { + fragName := "" + if fragment.Name != nil { + fragName = fragment.Name.Value } + fragmentNameUsed[fragName] = true } } - for _, spreadWithinOperation := range spreadsWithinOperation { - reduceSpreadFragments(spreadWithinOperation) - } - errors := []error{} + for _, def := range fragmentDefs { defName := "" if def.Name != nil { @@ -798,17 +924,13 @@ func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { isFragNameUsed, ok := fragmentNameUsed[defName] if !ok || isFragNameUsed != true { - _, err := newValidationRuleError( + reportError( + context, fmt.Sprintf(`Fragment "%v" is never used.`, defName), []ast.Node{def}, ) - - errors = append(errors, err) } } - if len(errors) > 0 { - return visitor.ActionNoChange, errors - } return visitor.ActionNoChange, nil }, }, @@ -819,46 +941,62 @@ func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { } } -/** - * NoUnusedVariablesRule - * No unused variables - * - * A GraphQL operation is only valid if all variables defined by an operation - * are used, either directly or within a spread fragment. - */ +func UnusedVariableMessage(varName string, opName string) string { + if opName != "" { + return fmt.Sprintf(`Variable "$%v" is never used in operation "%v".`, varName, opName) + } + return fmt.Sprintf(`Variable "$%v" is never used.`, varName) +} + +// NoUnusedVariablesRule No unused variables +// +// A GraphQL operation is only valid if all variables defined by an operation +// are used, either directly or within a spread fragment. func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { - var visitedFragmentNames = map[string]bool{} var variableDefs = []*ast.VariableDefinition{} - var variableNameUsed = map[string]bool{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ Enter: func(p visitor.VisitFuncParams) (string, interface{}) { - visitedFragmentNames = map[string]bool{} variableDefs = []*ast.VariableDefinition{} - variableNameUsed = map[string]bool{} return visitor.ActionNoChange, nil }, Leave: func(p visitor.VisitFuncParams) (string, interface{}) { - errors := []error{} - for _, def := range variableDefs { - variableName := "" - if def.Variable != nil && def.Variable.Name != nil { - variableName = def.Variable.Name.Value + if operation, ok := p.Node.(*ast.OperationDefinition); ok && operation != nil { + variableNameUsed := map[string]bool{} + usages := context.RecursiveVariableUsages(operation) + + for _, usage := range usages { + varName := "" + if usage != nil && usage.Node != nil && usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + if varName != "" { + variableNameUsed[varName] = true + } } - if isVariableNameUsed, _ := variableNameUsed[variableName]; isVariableNameUsed != true { - _, err := newValidationRuleError( - fmt.Sprintf(`Variable "$%v" is never used.`, variableName), - []ast.Node{def}, - ) - errors = append(errors, err) + for _, variableDef := range variableDefs { + variableName := "" + if variableDef != nil && variableDef.Variable != nil && variableDef.Variable.Name != nil { + variableName = variableDef.Variable.Name.Value + } + opName := "" + if operation.Name != nil { + opName = operation.Name.Value + } + if res, ok := variableNameUsed[variableName]; !ok || !res { + reportError( + context, + UnusedVariableMessage(variableName, opName), + []ast.Node{variableDef}, + ) + } } + } - if len(errors) > 0 { - return visitor.ActionNoChange, errors - } + return visitor.ActionNoChange, nil }, }, @@ -867,48 +1005,20 @@ func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { if def, ok := p.Node.(*ast.VariableDefinition); ok && def != nil { variableDefs = append(variableDefs, def) } - // Do not visit deeper, or else the defined variable name will be visited. - return visitor.ActionSkip, nil - }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variable, ok := p.Node.(*ast.Variable); ok && variable != nil { - if variable.Name != nil { - variableNameUsed[variable.Name.Value] = true - } - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if spreadAST, ok := p.Node.(*ast.FragmentSpread); ok && spreadAST != nil { - // Only visit fragments of a particular name once per operation - spreadName := "" - if spreadAST.Name != nil { - spreadName = spreadAST.Name.Value - } - if hasVisitedFragmentNames, _ := visitedFragmentNames[spreadName]; hasVisitedFragmentNames == true { - return visitor.ActionSkip, nil - } - visitedFragmentNames[spreadName] = true - } return visitor.ActionNoChange, nil }, }, }, } return &ValidationRuleInstance{ - // Visit FragmentDefinition after visiting FragmentSpread - VisitSpreadFragments: true, - VisitorOpts: visitorOpts, + VisitorOpts: visitorOpts, } } type fieldDefPair struct { - Field *ast.Field - FieldDef *FieldDefinition + ParentType Composite + Field *ast.Field + FieldDef *FieldDefinition } func collectFieldASTsAndDefs(context *ValidationContext, parentType Named, selectionSet *ast.SelectionSet, visitedFragmentNames map[string]bool, astAndDefs map[string][]*fieldDefPair) map[string][]*fieldDefPair { @@ -945,15 +1055,27 @@ func collectFieldASTsAndDefs(context *ValidationContext, parentType Named, selec if !ok { astAndDefs[responseName] = []*fieldDefPair{} } - astAndDefs[responseName] = append(astAndDefs[responseName], &fieldDefPair{ - Field: selection, - FieldDef: fieldDef, - }) + if parentType, ok := parentType.(Composite); ok { + astAndDefs[responseName] = append(astAndDefs[responseName], &fieldDefPair{ + ParentType: parentType, + Field: selection, + FieldDef: fieldDef, + }) + } else { + astAndDefs[responseName] = append(astAndDefs[responseName], &fieldDefPair{ + Field: selection, + FieldDef: fieldDef, + }) + } case *ast.InlineFragment: - parentType, _ := typeFromAST(*context.Schema(), selection.TypeCondition) + inlineFragmentType := parentType + if selection.TypeCondition != nil { + parentType, _ := typeFromAST(*context.Schema(), selection.TypeCondition) + inlineFragmentType = parentType + } astAndDefs = collectFieldASTsAndDefs( context, - parentType, + inlineFragmentType, selection.SelectionSet, visitedFragmentNames, astAndDefs, @@ -984,10 +1106,8 @@ func collectFieldASTsAndDefs(context *ValidationContext, parentType Named, selec return astAndDefs } -/** - * pairSet A way to keep track of pairs of things when the ordering of the pair does - * not matter. We do this by maintaining a sort of double adjacency sets. - */ +// pairSet A way to keep track of pairs of things when the ordering of the pair does +// not matter. We do this by maintaining a sort of double adjacency sets. type pairSet struct { data map[ast.Node]*nodeSet } @@ -1026,41 +1146,11 @@ type conflictReason struct { Message interface{} // conflictReason || []conflictReason } type conflict struct { - Reason conflictReason - Fields []ast.Node + Reason conflictReason + FieldsLeft []ast.Node + FieldsRight []ast.Node } -func sameDirectives(directives1 []*ast.Directive, directives2 []*ast.Directive) bool { - if len(directives1) != len(directives1) { - return false - } - for _, directive1 := range directives1 { - directive1Name := "" - if directive1.Name != nil { - directive1Name = directive1.Name.Value - } - - var foundDirective2 *ast.Directive - for _, directive2 := range directives2 { - directive2Name := "" - if directive2.Name != nil { - directive2Name = directive2.Name.Value - } - if directive1Name == directive2Name { - foundDirective2 = directive2 - } - break - } - if foundDirective2 == nil { - return false - } - if sameArguments(directive1.Arguments, foundDirective2.Arguments) == false { - return false - } - } - - return true -} func sameArguments(args1 []*ast.Argument, args2 []*ast.Argument) bool { if len(args1) != len(args2) { return false @@ -1102,33 +1192,67 @@ func sameValue(value1 ast.Value, value2 ast.Value) bool { return val1 == val2 } -func sameType(type1 Type, type2 Type) bool { - t := fmt.Sprintf("%v", type1) - t2 := fmt.Sprintf("%v", type2) - return t == t2 + +func sameType(typeA, typeB Type) bool { + if typeA == typeB { + return true + } + + if typeA, ok := typeA.(*List); ok { + if typeB, ok := typeB.(*List); ok { + return sameType(typeA.OfType, typeB.OfType) + } + } + if typeA, ok := typeA.(*NonNull); ok { + if typeB, ok := typeB.(*NonNull); ok { + return sameType(typeA.OfType, typeB.OfType) + } + } + + return false } -/** - * OverlappingFieldsCanBeMergedRule - * Overlapping fields can be merged - * - * A selection set is only valid if all fields (including spreading any - * fragments) either correspond to distinct response names or can be merged - * without ambiguity. - */ +// OverlappingFieldsCanBeMergedRule Overlapping fields can be merged +// +// A selection set is only valid if all fields (including spreading any +// fragments) either correspond to distinct response names or can be merged +// without ambiguity. func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRuleInstance { comparedSet := newPairSet() var findConflicts func(fieldMap map[string][]*fieldDefPair) (conflicts []*conflict) - findConflict := func(responseName string, pair *fieldDefPair, pair2 *fieldDefPair) *conflict { + findConflict := func(responseName string, field *fieldDefPair, field2 *fieldDefPair) *conflict { - ast1 := pair.Field - def1 := pair.FieldDef + parentType1 := field.ParentType + ast1 := field.Field + def1 := field.FieldDef - ast2 := pair2.Field - def2 := pair2.FieldDef + parentType2 := field2.ParentType + ast2 := field2.Field + def2 := field2.FieldDef + + // Not a pair. + if ast1 == ast2 { + return nil + } + + // If the statically known parent types could not possibly apply at the same + // time, then it is safe to permit them to diverge as they will not present + // any ambiguity by differing. + // It is known that two parent types could never overlap if they are + // different Object types. Interface or Union types might overlap - if not + // in the current state of the schema, then perhaps in some future version, + // thus may not safely diverge. + if parentType1 != parentType2 { + _, ok1 := parentType1.(*Object) + _, ok2 := parentType2.(*Object) + if ok1 && ok2 { + return nil + } + } - if ast1 == ast2 || comparedSet.Has(ast1, ast2) { + // Memoize, do not report the same issue twice. + if comparedSet.Has(ast1, ast2) { return nil } comparedSet.Add(ast1, ast2) @@ -1147,7 +1271,8 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul Name: responseName, Message: fmt.Sprintf(`%v and %v are different fields`, name1, name2), }, - Fields: []ast.Node{ast1, ast2}, + FieldsLeft: []ast.Node{ast1}, + FieldsRight: []ast.Node{ast2}, } } @@ -1166,7 +1291,8 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul Name: responseName, Message: fmt.Sprintf(`they return differing types %v and %v`, type1, type2), }, - Fields: []ast.Node{ast1, ast2}, + FieldsLeft: []ast.Node{ast1}, + FieldsRight: []ast.Node{ast2}, } } if !sameArguments(ast1.Arguments, ast2.Arguments) { @@ -1175,16 +1301,8 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul Name: responseName, Message: `they have differing arguments`, }, - Fields: []ast.Node{ast1, ast2}, - } - } - if !sameDirectives(ast1.Directives, ast2.Directives) { - return &conflict{ - Reason: conflictReason{ - Name: responseName, - Message: `they have differing directives`, - }, - Fields: []ast.Node{ast1, ast2}, + FieldsLeft: []ast.Node{ast1}, + FieldsRight: []ast.Node{ast2}, } } @@ -1210,10 +1328,12 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul if len(conflicts) > 0 { conflictReasons := []conflictReason{} - conflictFields := []ast.Node{ast1, ast2} + conflictFieldsLeft := []ast.Node{ast1} + conflictFieldsRight := []ast.Node{ast2} for _, c := range conflicts { conflictReasons = append(conflictReasons, c.Reason) - conflictFields = append(conflictFields, c.Fields...) + conflictFieldsLeft = append(conflictFieldsLeft, c.FieldsLeft...) + conflictFieldsRight = append(conflictFieldsRight, c.FieldsRight...) } return &conflict{ @@ -1221,7 +1341,8 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul Name: responseName, Message: conflictReasons, }, - Fields: conflictFields, + FieldsLeft: conflictFieldsLeft, + FieldsRight: conflictFieldsRight, } } } @@ -1232,7 +1353,7 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul // ensure field traversal orderedName := sort.StringSlice{} - for responseName, _ := range fieldMap { + for responseName := range fieldMap { orderedName = append(orderedName, responseName) } orderedName.Sort() @@ -1287,22 +1408,20 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul ) conflicts := findConflicts(fieldMap) if len(conflicts) > 0 { - errors := []error{} for _, c := range conflicts { responseName := c.Reason.Name reason := c.Reason - _, err := newValidationRuleError( + reportError( + context, fmt.Sprintf( `Fields "%v" conflict because %v.`, responseName, reasonMessage(reason), ), - c.Fields, + append(c.FieldsLeft, c.FieldsRight...), ) - errors = append(errors, err) - } - return visitor.ActionNoChange, errors + return visitor.ActionNoChange, nil } } return visitor.ActionNoChange, nil @@ -1366,14 +1485,11 @@ func doTypesOverlap(t1 Type, t2 Type) bool { return false } -/** - * PossibleFragmentSpreadsRule - * Possible fragment spread - * - * A fragment spread is only valid if the type condition could ever possibly - * be true: if there is a non-empty intersection of the possible parent types, - * and possible types which pass the type condition. - */ +// PossibleFragmentSpreadsRule Possible fragment spread +// +// A fragment spread is only valid if the type condition could ever possibly +// be true: if there is a non-empty intersection of the possible parent types, +// and possible types which pass the type condition. func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ @@ -1385,7 +1501,8 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst parentType, _ := context.ParentType().(Type) if fragType != nil && parentType != nil && !doTypesOverlap(fragType, parentType) { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Fragment cannot be spread here as objects of `+ `type "%v" can never be of type "%v".`, parentType, fragType), []ast.Node{node}, @@ -1405,7 +1522,8 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst fragType := getFragmentType(context, fragName) parentType, _ := context.ParentType().(Type) if fragType != nil && parentType != nil && !doTypesOverlap(fragType, parentType) { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Fragment "%v" cannot be spread here as objects of `+ `type "%v" can never be of type "%v".`, fragName, parentType, fragType), []ast.Node{node}, @@ -1422,13 +1540,10 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst } } -/** - * ProvidedNonNullArgumentsRule - * Provided required arguments - * - * A field or directive is only valid if all required (non-null) field arguments - * have been provided. - */ +// ProvidedNonNullArgumentsRule Provided required arguments +// +// A field or directive is only valid if all required (non-null) field arguments +// have been provided. func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ @@ -1442,7 +1557,6 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns return visitor.ActionSkip, nil } - errors := []error{} argASTs := fieldAST.Arguments argASTMap := map[string]*ast.Argument{} @@ -1461,18 +1575,15 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if fieldAST.Name != nil { fieldName = fieldAST.Name.Value } - _, err := newValidationRuleError( + reportError( + context, fmt.Sprintf(`Field "%v" argument "%v" of type "%v" `+ `is required but not provided.`, fieldName, argDef.Name(), argDefType), []ast.Node{fieldAST}, ) - errors = append(errors, err) } } } - if len(errors) > 0 { - return visitor.ActionNoChange, errors - } } return visitor.ActionNoChange, nil }, @@ -1486,7 +1597,6 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if directiveDef == nil { return visitor.ActionSkip, nil } - errors := []error{} argASTs := directiveAST.Arguments argASTMap := map[string]*ast.Argument{} @@ -1506,18 +1616,15 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if directiveAST.Name != nil { directiveName = directiveAST.Name.Value } - _, err := newValidationRuleError( + reportError( + context, fmt.Sprintf(`Directive "@%v" argument "%v" of type `+ `"%v" is required but not provided.`, directiveName, argDef.Name(), argDefType), []ast.Node{directiveAST}, ) - errors = append(errors, err) } } } - if len(errors) > 0 { - return visitor.ActionNoChange, errors - } } return visitor.ActionNoChange, nil }, @@ -1529,13 +1636,10 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns } } -/** - * ScalarLeafsRule - * Scalar leafs - * - * A GraphQL document is valid only if all leaf fields (fields without - * sub selections) are of scalar or enum types. - */ +// ScalarLeafsRule Scalar leafs +// +// A GraphQL document is valid only if all leaf fields (fields without +// sub selections) are of scalar or enum types. func ScalarLeafsRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ @@ -1551,13 +1655,15 @@ func ScalarLeafsRule(context *ValidationContext) *ValidationRuleInstance { if ttype != nil { if IsLeafType(ttype) { if node.SelectionSet != nil { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Field "%v" of type "%v" must not have a sub selection.`, nodeName, ttype), []ast.Node{node.SelectionSet}, ) } } else if node.SelectionSet == nil { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Field "%v" of type "%v" must have a sub selection.`, nodeName, ttype), []ast.Node{node}, ) @@ -1574,13 +1680,10 @@ func ScalarLeafsRule(context *ValidationContext) *ValidationRuleInstance { } } -/** - * UniqueArgumentNamesRule - * Unique argument names - * - * A GraphQL field or directive is only valid if all supplied arguments are - * uniquely named. - */ +// UniqueArgumentNamesRule Unique argument names +// +// A GraphQL field or directive is only valid if all supplied arguments are +// uniquely named. func UniqueArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance { knownArgNames := map[string]*ast.Name{} @@ -1606,14 +1709,16 @@ func UniqueArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance argName = node.Name.Value } if nameAST, ok := knownArgNames[argName]; ok { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`There can be only one argument named "%v".`, argName), []ast.Node{nameAST, node.Name}, ) + } else { + knownArgNames[argName] = node.Name } - knownArgNames[argName] = node.Name } - return visitor.ActionNoChange, nil + return visitor.ActionSkip, nil }, }, }, @@ -1623,17 +1728,19 @@ func UniqueArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance } } -/** - * UniqueFragmentNamesRule - * Unique fragment names - * - * A GraphQL document is only valid if all defined fragments have unique names. - */ +// UniqueFragmentNamesRule Unique fragment names +// +// A GraphQL document is only valid if all defined fragments have unique names. func UniqueFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance { knownFragmentNames := map[string]*ast.Name{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.OperationDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, kinds.FragmentDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.FragmentDefinition); ok && node != nil { @@ -1642,16 +1749,68 @@ func UniqueFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance fragmentName = node.Name.Value } if nameAST, ok := knownFragmentNames[fragmentName]; ok { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`There can only be one fragment named "%v".`, fragmentName), []ast.Node{nameAST, node.Name}, ) + } else { + knownFragmentNames[fragmentName] = node.Name } - knownFragmentNames[fragmentName] = node.Name } + return visitor.ActionSkip, nil + }, + }, + }, + } + return &ValidationRuleInstance{ + VisitorOpts: visitorOpts, + } +} + +// UniqueInputFieldNamesRule Unique input field names +// +// A GraphQL input object value is only valid if all supplied fields are +// uniquely named. +func UniqueInputFieldNamesRule(context *ValidationContext) *ValidationRuleInstance { + knownNameStack := []map[string]*ast.Name{} + knownNames := map[string]*ast.Name{} + + visitorOpts := &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.ObjectValue: visitor.NamedVisitFuncs{ + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + knownNameStack = append(knownNameStack, knownNames) + knownNames = map[string]*ast.Name{} + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + // pop + knownNames, knownNameStack = knownNameStack[len(knownNameStack)-1], knownNameStack[:len(knownNameStack)-1] return visitor.ActionNoChange, nil }, }, + kinds.ObjectField: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.ObjectField); ok { + fieldName := "" + if node.Name != nil { + fieldName = node.Name.Value + } + if knownNameAST, ok := knownNames[fieldName]; ok { + reportError( + context, + fmt.Sprintf(`There can be only one input field named "%v".`, fieldName), + []ast.Node{knownNameAST, node.Name}, + ) + } else { + knownNames[fieldName] = node.Name + } + + } + return visitor.ActionSkip, nil + }, + }, }, } return &ValidationRuleInstance{ @@ -1659,12 +1818,9 @@ func UniqueFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance } } -/** - * UniqueOperationNamesRule - * Unique operation names - * - * A GraphQL document is only valid if all defined operations have unique names. - */ +// UniqueOperationNamesRule Unique operation names +// +// A GraphQL document is only valid if all defined operations have unique names. func UniqueOperationNamesRule(context *ValidationContext) *ValidationRuleInstance { knownOperationNames := map[string]*ast.Name{} @@ -1678,12 +1834,64 @@ func UniqueOperationNamesRule(context *ValidationContext) *ValidationRuleInstanc operationName = node.Name.Value } if nameAST, ok := knownOperationNames[operationName]; ok { - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`There can only be one operation named "%v".`, operationName), []ast.Node{nameAST, node.Name}, ) + } else { + knownOperationNames[operationName] = node.Name + } + } + return visitor.ActionSkip, nil + }, + }, + kinds.FragmentDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + }, + } + return &ValidationRuleInstance{ + VisitorOpts: visitorOpts, + } +} + +// UniqueVariableNamesRule Unique variable names +// +// A GraphQL operation is only valid if all its variables are uniquely named. +func UniqueVariableNamesRule(context *ValidationContext) *ValidationRuleInstance { + knownVariableNames := map[string]*ast.Name{} + + visitorOpts := &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.OperationDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.OperationDefinition); ok && node != nil { + knownVariableNames = map[string]*ast.Name{} + } + return visitor.ActionNoChange, nil + }, + }, + kinds.VariableDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { + variableName := "" + var variableNameAST *ast.Name + if node.Variable != nil && node.Variable.Name != nil { + variableNameAST = node.Variable.Name + variableName = node.Variable.Name.Value + } + if nameAST, ok := knownVariableNames[variableName]; ok { + reportError( + context, + fmt.Sprintf(`There can only be one variable named "%v".`, variableName), + []ast.Node{nameAST, variableNameAST}, + ) + } else { + knownVariableNames[variableName] = variableNameAST } - knownOperationNames[operationName] = node.Name } return visitor.ActionNoChange, nil }, @@ -1695,13 +1903,10 @@ func UniqueOperationNamesRule(context *ValidationContext) *ValidationRuleInstanc } } -/** - * VariablesAreInputTypesRule - * Variables are input types - * - * A GraphQL operation is only valid if all the variables it defines are of - * input types (scalar, enum, or input object). - */ +// VariablesAreInputTypesRule Variables are input types +// +// A GraphQL operation is only valid if all the variables it defines are of +// input types (scalar, enum, or input object). func VariablesAreInputTypesRule(context *ValidationContext) *ValidationRuleInstance { visitorOpts := &visitor.VisitorOptions{ @@ -1717,7 +1922,8 @@ func VariablesAreInputTypesRule(context *ValidationContext) *ValidationRuleInsta if node.Variable != nil && node.Variable.Name != nil { variableName = node.Variable.Name.Value } - return newValidationRuleError( + reportError( + context, fmt.Sprintf(`Variable "$%v" cannot be non-input type "%v".`, variableName, printer.Print(node.Type)), []ast.Node{node.Type}, @@ -1745,43 +1951,45 @@ func effectiveType(varType Type, varDef *ast.VariableDefinition) Type { return NewNonNull(varType) } -// A var type is allowed if it is the same or more strict than the expected -// type. It can be more strict if the variable type is non-null when the -// expected type is nullable. If both are list types, the variable item type can -// be more strict than the expected item type. -func varTypeAllowedForType(varType Type, expectedType Type) bool { - if expectedType, ok := expectedType.(*NonNull); ok { - if varType, ok := varType.(*NonNull); ok { - return varTypeAllowedForType(varType.OfType, expectedType.OfType) - } - return false - } - if varType, ok := varType.(*NonNull); ok { - return varTypeAllowedForType(varType.OfType, expectedType) - } - if varType, ok := varType.(*List); ok { - if expectedType, ok := expectedType.(*List); ok { - return varTypeAllowedForType(varType.OfType, expectedType.OfType) - } - } - return varType == expectedType -} - -/** - * VariablesInAllowedPositionRule - * Variables passed to field arguments conform to type - */ +// VariablesInAllowedPositionRule Variables passed to field arguments conform to type func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleInstance { varDefMap := map[string]*ast.VariableDefinition{} - visitedFragmentNames := map[string]bool{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { varDefMap = map[string]*ast.VariableDefinition{} - visitedFragmentNames = map[string]bool{} + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if operation, ok := p.Node.(*ast.OperationDefinition); ok { + + usages := context.RecursiveVariableUsages(operation) + for _, usage := range usages { + varName := "" + if usage != nil && usage.Node != nil && usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + varDef, _ := varDefMap[varName] + if varDef != nil && usage.Type != nil { + varType, err := typeFromAST(*context.Schema(), varDef.Type) + if err != nil { + varType = nil + } + if varType != nil && !isTypeSubTypeOf(effectiveType(varType, varDef), usage.Type) { + reportError( + context, + fmt.Sprintf(`Variable "$%v" of type "%v" used in position `+ + `expecting type "%v".`, varName, varType, usage.Type), + []ast.Node{varDef, usage.Node}, + ) + } + } + } + + } return visitor.ActionNoChange, nil }, }, @@ -1792,46 +2000,8 @@ func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleI if varDefAST.Variable != nil && varDefAST.Variable.Name != nil { defName = varDefAST.Variable.Name.Value } - varDefMap[defName] = varDefAST - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - // Only visit fragments of a particular name once per operation - if spreadAST, ok := p.Node.(*ast.FragmentSpread); ok { - spreadName := "" - if spreadAST.Name != nil { - spreadName = spreadAST.Name.Value - } - if hasVisited, _ := visitedFragmentNames[spreadName]; hasVisited { - return visitor.ActionSkip, nil - } - visitedFragmentNames[spreadName] = true - } - return visitor.ActionNoChange, nil - }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variableAST, ok := p.Node.(*ast.Variable); ok && variableAST != nil { - varName := "" - if variableAST.Name != nil { - varName = variableAST.Name.Value - } - varDef, _ := varDefMap[varName] - var varType Type - if varDef != nil { - varType, _ = typeFromAST(*context.Schema(), varDef.Type) - } - inputType := context.InputType() - if varType != nil && inputType != nil && !varTypeAllowedForType(effectiveType(varType, varDef), inputType) { - return newValidationRuleError( - fmt.Sprintf(`Variable "$%v" of type "%v" used in position `+ - `expecting type "%v".`, varName, varType, inputType), - []ast.Node{variableAST}, - ) + if defName != "" { + varDefMap[defName] = varDefAST } } return visitor.ActionNoChange, nil @@ -1840,48 +2010,50 @@ func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleI }, } return &ValidationRuleInstance{ - VisitSpreadFragments: true, - VisitorOpts: visitorOpts, + VisitorOpts: visitorOpts, } } -/** - * Utility for validators which determines if a value literal AST is valid given - * an input type. - * - * Note that this only validates literal values, variables are assumed to - * provide values of the correct type. - */ -func isValidLiteralValue(ttype Input, valueAST ast.Value) bool { +// Utility for validators which determines if a value literal AST is valid given +// an input type. +// +// Note that this only validates literal values, variables are assumed to +// provide values of the correct type. +func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { // A value must be provided if the type is non-null. if ttype, ok := ttype.(*NonNull); ok { if valueAST == nil { - return false + if ttype.OfType.Name() != "" { + return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())} + } + return false, []string{"Expected non-null value, found null."} } ofType, _ := ttype.OfType.(Input) return isValidLiteralValue(ofType, valueAST) } if valueAST == nil { - return true + return true, nil } // This function only tests literals, and assumes variables will provide // values of the correct type. if valueAST.GetKind() == kinds.Variable { - return true + return true, nil } // Lists accept a non-list value as a list of one. if ttype, ok := ttype.(*List); ok { itemType, _ := ttype.OfType.(Input) if valueAST, ok := valueAST.(*ast.ListValue); ok { + messagesReduce := []string{} for _, value := range valueAST.Values { - if isValidLiteralValue(itemType, value) == false { - return false + _, messages := isValidLiteralValue(itemType, value) + for idx, message := range messages { + messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, idx+1, message)) } } - return true + return (len(messagesReduce) == 0), messagesReduce } return isValidLiteralValue(itemType, valueAST) @@ -1891,12 +2063,12 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) bool { if ttype, ok := ttype.(*InputObject); ok { valueAST, ok := valueAST.(*ast.ObjectValue) if !ok { - return false + return false, []string{fmt.Sprintf(`Expected "%v", found not an object.`, ttype.Name())} } fields := ttype.Fields() + messagesReduce := []string{} // Ensure every provided field is defined. - // Ensure every defined field is valid. fieldASTs := valueAST.Fields fieldASTMap := map[string]*ast.ObjectField{} for _, fieldAST := range fieldASTs { @@ -1907,55 +2079,37 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) bool { fieldASTMap[fieldASTName] = fieldAST - // check if field is defined field, ok := fields[fieldASTName] if !ok || field == nil { - return false + messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": Unknown field.`, fieldASTName)) } } + // Ensure every defined field is valid. for fieldName, field := range fields { fieldAST, _ := fieldASTMap[fieldName] var fieldASTValue ast.Value if fieldAST != nil { fieldASTValue = fieldAST.Value } - if !isValidLiteralValue(field.Type, fieldASTValue) { - return false + if isValid, messages := isValidLiteralValue(field.Type, fieldASTValue); !isValid { + for _, message := range messages { + messagesReduce = append(messagesReduce, fmt.Sprintf("In field \"%v\": %v", fieldName, message)) + } } } - return true + return (len(messagesReduce) == 0), messagesReduce } if ttype, ok := ttype.(*Scalar); ok { - return !isNullish(ttype.ParseLiteral(valueAST)) + if isNullish(ttype.ParseLiteral(valueAST)) { + return false, []string{fmt.Sprintf(`Expected type "%v", found %v.`, ttype.Name(), printer.Print(valueAST))} + } } if ttype, ok := ttype.(*Enum); ok { - return !isNullish(ttype.ParseLiteral(valueAST)) + if isNullish(ttype.ParseLiteral(valueAST)) { + return false, []string{fmt.Sprintf(`Expected type "%v", found %v.`, ttype.Name(), printer.Print(valueAST))} + } } - // Must be input type (not scalar or enum) - // Silently fail, instead of panic() - return false -} - -/** - * Given an operation or fragment AST node, gather all the - * named spreads defined within the scope of the fragment - * or operation - */ -func gatherSpreads(node ast.Node) (spreadNodes []*ast.FragmentSpread) { - visitorOpts := &visitor.VisitorOptions{ - KindFuncMap: map[string]visitor.NamedVisitFuncs{ - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.FragmentSpread); ok && node != nil { - spreadNodes = append(spreadNodes, node) - } - return visitor.ActionNoChange, nil - }, - }, - }, - } - visitor.Visit(node, visitorOpts, nil) - return spreadNodes + return true, nil } diff --git a/rules_arguments_of_correct_type_test.go b/rules_arguments_of_correct_type_test.go index 27a2443b..ecd4bea4 100644 --- a/rules_arguments_of_correct_type_test.go +++ b/rules_arguments_of_correct_type_test.go @@ -91,7 +91,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidStringValues_IntIntoString(t *te `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "stringArg" expected type "String" but got: 1.`, + "Argument \"stringArg\" has invalid value 1.\nExpected type \"String\", found 1.", 4, 39, ), }) @@ -106,7 +106,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidStringValues_FloatIntoString(t * `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "stringArg" expected type "String" but got: 1.0.`, + "Argument \"stringArg\" has invalid value 1.0.\nExpected type \"String\", found 1.0.", 4, 39, ), }) @@ -121,7 +121,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidStringValues_BooleanIntoString(t `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "stringArg" expected type "String" but got: true.`, + "Argument \"stringArg\" has invalid value true.\nExpected type \"String\", found true.", 4, 39, ), }) @@ -136,7 +136,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidStringValues_UnquotedStringIntoS `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "stringArg" expected type "String" but got: BAR.`, + "Argument \"stringArg\" has invalid value BAR.\nExpected type \"String\", found BAR.", 4, 39, ), }) @@ -152,7 +152,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIntValues_StringIntoInt(t *testi `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "intArg" expected type "Int" but got: "3".`, + "Argument \"intArg\" has invalid value \"3\".\nExpected type \"Int\", found \"3\".", 4, 33, ), }) @@ -167,7 +167,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIntValues_BigIntIntoInt(t *testi `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "intArg" expected type "Int" but got: 829384293849283498239482938.`, + "Argument \"intArg\" has invalid value 829384293849283498239482938.\nExpected type \"Int\", found 829384293849283498239482938.", 4, 33, ), }) @@ -182,7 +182,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIntValues_UnquotedStringIntoInt( `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "intArg" expected type "Int" but got: FOO.`, + "Argument \"intArg\" has invalid value FOO.\nExpected type \"Int\", found FOO.", 4, 33, ), }) @@ -197,7 +197,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIntValues_SimpleFloatIntoInt(t * `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "intArg" expected type "Int" but got: 3.0.`, + "Argument \"intArg\" has invalid value 3.0.\nExpected type \"Int\", found 3.0.", 4, 33, ), }) @@ -212,7 +212,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIntValues_FloatIntoInt(t *testin `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "intArg" expected type "Int" but got: 3.333.`, + "Argument \"intArg\" has invalid value 3.333.\nExpected type \"Int\", found 3.333.", 4, 33, ), }) @@ -228,7 +228,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidFloatValues_StringIntoFloat(t *t `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "floatArg" expected type "Float" but got: "3.333".`, + "Argument \"floatArg\" has invalid value \"3.333\".\nExpected type \"Float\", found \"3.333\".", 4, 37, ), }) @@ -243,7 +243,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidFloatValues_BooleanIntoFloat(t * `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "floatArg" expected type "Float" but got: true.`, + "Argument \"floatArg\" has invalid value true.\nExpected type \"Float\", found true.", 4, 37, ), }) @@ -258,7 +258,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidFloatValues_UnquotedIntoFloat(t `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "floatArg" expected type "Float" but got: FOO.`, + "Argument \"floatArg\" has invalid value FOO.\nExpected type \"Float\", found FOO.", 4, 37, ), }) @@ -274,7 +274,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidBooleanValues_IntIntoBoolean(t * `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "booleanArg" expected type "Boolean" but got: 2.`, + "Argument \"booleanArg\" has invalid value 2.\nExpected type \"Boolean\", found 2.", 4, 41, ), }) @@ -289,7 +289,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidBooleanValues_FloatIntoBoolean(t `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "booleanArg" expected type "Boolean" but got: 1.0.`, + "Argument \"booleanArg\" has invalid value 1.0.\nExpected type \"Boolean\", found 1.0.", 4, 41, ), }) @@ -304,7 +304,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidBooleanValues_StringIntoBoolean( `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "booleanArg" expected type "Boolean" but got: "true".`, + "Argument \"booleanArg\" has invalid value \"true\".\nExpected type \"Boolean\", found \"true\".", 4, 41, ), }) @@ -319,7 +319,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidBooleanValues_UnquotedStringInto `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "booleanArg" expected type "Boolean" but got: TRUE.`, + "Argument \"booleanArg\" has invalid value TRUE.\nExpected type \"Boolean\", found TRUE.", 4, 41, ), }) @@ -335,7 +335,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIDValue_FloatIntoID(t *testing.T `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "idArg" expected type "ID" but got: 1.0.`, + "Argument \"idArg\" has invalid value 1.0.\nExpected type \"ID\", found 1.0.", 4, 31, ), }) @@ -350,7 +350,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIDValue_BooleanIntoID(t *testing `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "idArg" expected type "ID" but got: true.`, + "Argument \"idArg\" has invalid value true.\nExpected type \"ID\", found true.", 4, 31, ), }) @@ -365,7 +365,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidIDValue_UnquotedIntoID(t *testin `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "idArg" expected type "ID" but got: SOMETHING.`, + "Argument \"idArg\" has invalid value SOMETHING.\nExpected type \"ID\", found SOMETHING.", 4, 31, ), }) @@ -381,7 +381,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidEnumValue_IntIntoEnum(t *testing `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "dogCommand" expected type "DogCommand" but got: 2.`, + "Argument \"dogCommand\" has invalid value 2.\nExpected type \"DogCommand\", found 2.", 4, 41, ), }) @@ -396,7 +396,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidEnumValue_FloatIntoEnum(t *testi `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "dogCommand" expected type "DogCommand" but got: 1.0.`, + "Argument \"dogCommand\" has invalid value 1.0.\nExpected type \"DogCommand\", found 1.0.", 4, 41, ), }) @@ -411,7 +411,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidEnumValue_StringIntoEnum(t *test `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "dogCommand" expected type "DogCommand" but got: "SIT".`, + "Argument \"dogCommand\" has invalid value \"SIT\".\nExpected type \"DogCommand\", found \"SIT\".", 4, 41, ), }) @@ -426,7 +426,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidEnumValue_BooleanIntoEnum(t *tes `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "dogCommand" expected type "DogCommand" but got: true.`, + "Argument \"dogCommand\" has invalid value true.\nExpected type \"DogCommand\", found true.", 4, 41, ), }) @@ -441,7 +441,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidEnumValue_UnknownEnumValueIntoEn `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "dogCommand" expected type "DogCommand" but got: JUGGLE.`, + "Argument \"dogCommand\" has invalid value JUGGLE.\nExpected type \"DogCommand\", found JUGGLE.", 4, 41, ), }) @@ -456,7 +456,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidEnumValue_DifferentCaseEnumValue `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "dogCommand" expected type "DogCommand" but got: sit.`, + "Argument \"dogCommand\" has invalid value sit.\nExpected type \"DogCommand\", found sit.", 4, 41, ), }) @@ -500,7 +500,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidListValue_IncorrectItemType(t *t `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "stringListArg" expected type "[String]" but got: ["one", 2].`, + "Argument \"stringListArg\" has invalid value [\"one\", 2].\nIn element #1: Expected type \"String\", found 2.", 4, 47, ), }) @@ -515,7 +515,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidListValue_SingleValueOfIncorrent `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "stringListArg" expected type "[String]" but got: 1.`, + "Argument \"stringListArg\" has invalid value 1.\nExpected type \"String\", found 1.", 4, 47, ), }) @@ -622,11 +622,11 @@ func TestValidate_ArgValuesOfCorrectType_InvalidNonNullableValue_IncorrectValueT `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "req2" expected type "Int!" but got: "two".`, + "Argument \"req2\" has invalid value \"two\".\nExpected type \"Int\", found \"two\".", 4, 32, ), testutil.RuleError( - `Argument "req1" expected type "Int!" but got: "one".`, + "Argument \"req1\" has invalid value \"one\".\nExpected type \"Int\", found \"one\".", 4, 45, ), }) @@ -641,7 +641,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidNonNullableValue_IncorrectValueA `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "req1" expected type "Int!" but got: "one".`, + "Argument \"req1\" has invalid value \"one\".\nExpected type \"Int\", found \"one\".", 4, 32, ), }) @@ -724,7 +724,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidInputObjectValue_PartialObject_M `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "complexArg" expected type "ComplexInput" but got: {intField: 4}.`, + "Argument \"complexArg\" has invalid value {intField: 4}.\nIn field \"requiredField\": Expected \"Boolean!\", found null.", 4, 41, ), }) @@ -742,7 +742,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidInputObjectValue_PartialObject_I `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "complexArg" expected type "ComplexInput" but got: {stringListField: ["one", 2], requiredField: true}.`, + "Argument \"complexArg\" has invalid value {stringListField: [\"one\", 2], requiredField: true}.\nIn field \"stringListField\": In element #1: Expected type \"String\", found 2.", 4, 41, ), }) @@ -760,7 +760,7 @@ func TestValidate_ArgValuesOfCorrectType_InvalidInputObjectValue_PartialObject_U `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "complexArg" expected type "ComplexInput" but got: {requiredField: true, unknownField: "value"}.`, + "Argument \"complexArg\" has invalid value {requiredField: true, unknownField: \"value\"}.\nIn field \"unknownField\": Unknown field.", 4, 41, ), }) @@ -788,11 +788,13 @@ func TestValidate_ArgValuesOfCorrectType_DirectiveArguments_WithDirectivesWithIn `, []gqlerrors.FormattedError{ testutil.RuleError( - `Argument "if" expected type "Boolean!" but got: "yes".`, + `Argument "if" has invalid value "yes".`+ + "\nExpected type \"Boolean\", found \"yes\".", 3, 28, ), testutil.RuleError( - `Argument "if" expected type "Boolean!" but got: ENUM.`, + `Argument "if" has invalid value ENUM.`+ + "\nExpected type \"Boolean\", found ENUM.", 4, 28, ), }) diff --git a/rules_default_values_of_correct_type_test.go b/rules_default_values_of_correct_type_test.go index 8ef76210..bc9545be 100644 --- a/rules_default_values_of_correct_type_test.go +++ b/rules_default_values_of_correct_type_test.go @@ -63,9 +63,16 @@ func TestValidate_VariableDefaultValuesOfCorrectType_VariablesWithInvalidDefault } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" of type "Int" has invalid default value: "one".`, 3, 19), - testutil.RuleError(`Variable "$b" of type "String" has invalid default value: 4.`, 4, 22), - testutil.RuleError(`Variable "$c" of type "ComplexInput" has invalid default value: "notverycomplex".`, 5, 28), + testutil.RuleError(`Variable "$a" has invalid default value: "one".`+ + "\nExpected type \"Int\", found \"one\".", + 3, 19), + testutil.RuleError(`Variable "$b" has invalid default value: 4.`+ + "\nExpected type \"String\", found 4.", + 4, 22), + testutil.RuleError( + `Variable "$c" has invalid default value: "notverycomplex".`+ + "\nExpected \"ComplexInput\", found not an object.", + 5, 28), }) } func TestValidate_VariableDefaultValuesOfCorrectType_ComplexVariablesMissingRequiredField(t *testing.T) { @@ -75,7 +82,10 @@ func TestValidate_VariableDefaultValuesOfCorrectType_ComplexVariablesMissingRequ } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" of type "ComplexInput" has invalid default value: {intField: 3}.`, 2, 53), + testutil.RuleError( + `Variable "$a" has invalid default value: {intField: 3}.`+ + "\nIn field \"requiredField\": Expected \"Boolean!\", found null.", + 2, 53), }) } func TestValidate_VariableDefaultValuesOfCorrectType_ListVariablesWithInvalidItem(t *testing.T) { @@ -85,6 +95,9 @@ func TestValidate_VariableDefaultValuesOfCorrectType_ListVariablesWithInvalidIte } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" of type "[String]" has invalid default value: ["one", 2].`, 2, 40), + testutil.RuleError( + `Variable "$a" has invalid default value: ["one", 2].`+ + "\nIn element #1: Expected type \"String\", found 2.", + 2, 40), }) } diff --git a/rules_fields_on_correct_type_test.go b/rules_fields_on_correct_type_test.go index af1f571e..294a0682 100644 --- a/rules_fields_on_correct_type_test.go +++ b/rules_fields_on_correct_type_test.go @@ -53,16 +53,30 @@ func TestValidate_FieldsOnCorrectType_IgnoresFieldsOnUnknownType(t *testing.T) { } `) } +func TestValidate_FieldsOnCorrectType_ReportErrorsWhenTheTypeIsKnownAgain(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.FieldsOnCorrectTypeRule, ` + fragment typeKnownAgain on Pet { + unknown_pet_field { + ... on Cat { + unknown_cat_field + } + } + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Cannot query field "unknown_pet_field" on type "Pet".`, 3, 9), + testutil.RuleError(`Cannot query field "unknown_cat_field" on type "Cat".`, 5, 13), + }) +} func TestValidate_FieldsOnCorrectType_FieldNotDefinedOnFragment(t *testing.T) { testutil.ExpectFailsRule(t, graphql.FieldsOnCorrectTypeRule, ` fragment fieldNotDefined on Dog { meowVolume } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "meowVolume" on "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "meowVolume" on type "Dog".`, 3, 9), }) } -func TestValidate_FieldsOnCorrectType_FieldNotDefinedDeeplyOnlyReportsFirst(t *testing.T) { +func TestValidate_FieldsOnCorrectType_IgnoreDeeplyUnknownField(t *testing.T) { testutil.ExpectFailsRule(t, graphql.FieldsOnCorrectTypeRule, ` fragment deepFieldNotDefined on Dog { unknown_field { @@ -70,7 +84,7 @@ func TestValidate_FieldsOnCorrectType_FieldNotDefinedDeeplyOnlyReportsFirst(t *t } } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "unknown_field" on "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "unknown_field" on type "Dog".`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_SubFieldNotDefined(t *testing.T) { @@ -81,7 +95,7 @@ func TestValidate_FieldsOnCorrectType_SubFieldNotDefined(t *testing.T) { } } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "unknown_field" on "Pet".`, 4, 11), + testutil.RuleError(`Cannot query field "unknown_field" on type "Pet".`, 4, 11), }) } func TestValidate_FieldsOnCorrectType_FieldNotDefinedOnInlineFragment(t *testing.T) { @@ -92,7 +106,7 @@ func TestValidate_FieldsOnCorrectType_FieldNotDefinedOnInlineFragment(t *testing } } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "meowVolume" on "Dog".`, 4, 11), + testutil.RuleError(`Cannot query field "meowVolume" on type "Dog".`, 4, 11), }) } func TestValidate_FieldsOnCorrectType_AliasedFieldTargetNotDefined(t *testing.T) { @@ -101,7 +115,7 @@ func TestValidate_FieldsOnCorrectType_AliasedFieldTargetNotDefined(t *testing.T) volume : mooVolume } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "mooVolume" on "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "mooVolume" on type "Dog".`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_AliasedLyingFieldTargetNotDefined(t *testing.T) { @@ -110,7 +124,7 @@ func TestValidate_FieldsOnCorrectType_AliasedLyingFieldTargetNotDefined(t *testi barkVolume : kawVolume } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "kawVolume" on "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "kawVolume" on type "Dog".`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_NotDefinedOnInterface(t *testing.T) { @@ -119,7 +133,7 @@ func TestValidate_FieldsOnCorrectType_NotDefinedOnInterface(t *testing.T) { tailLength } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "tailLength" on "Pet".`, 3, 9), + testutil.RuleError(`Cannot query field "tailLength" on type "Pet".`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_DefinedOnImplementorsButNotOnInterface(t *testing.T) { @@ -128,7 +142,7 @@ func TestValidate_FieldsOnCorrectType_DefinedOnImplementorsButNotOnInterface(t * nickname } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "nickname" on "Pet".`, 3, 9), + testutil.RuleError(`Cannot query field "nickname" on type "Pet". However, this field exists on "Cat", "Dog". Perhaps you meant to use an inline fragment?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_MetaFieldSelectionOnUnion(t *testing.T) { @@ -144,16 +158,16 @@ func TestValidate_FieldsOnCorrectType_DirectFieldSelectionOnUnion(t *testing.T) directField } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "directField" on "CatOrDog".`, 3, 9), + testutil.RuleError(`Cannot query field "directField" on type "CatOrDog".`, 3, 9), }) } -func TestValidate_FieldsOnCorrectType_DirectImplementorsQueriedOnUnion(t *testing.T) { +func TestValidate_FieldsOnCorrectType_DefinedImplementorsQueriedOnUnion(t *testing.T) { testutil.ExpectFailsRule(t, graphql.FieldsOnCorrectTypeRule, ` fragment definedOnImplementorsQueriedOnUnion on CatOrDog { name } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "name" on "CatOrDog".`, 3, 9), + testutil.RuleError(`Cannot query field "name" on type "CatOrDog". However, this field exists on "Being", "Pet", "Canine", "Cat", "Dog". Perhaps you meant to use an inline fragment?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_ValidFieldInInlineFragment(t *testing.T) { @@ -162,6 +176,36 @@ func TestValidate_FieldsOnCorrectType_ValidFieldInInlineFragment(t *testing.T) { ... on Dog { name } + ... { + name + } } `) } + +func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithNoSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage("T", "f", []string{}) + expected := `Cannot query field "T" on type "f".` + if message != expected { + t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) + } +} + +func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithNoSmallNumbersOfSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage("T", "f", []string{"A", "B"}) + expected := `Cannot query field "T" on type "f". ` + + `However, this field exists on "A", "B". ` + + `Perhaps you meant to use an inline fragment?` + if message != expected { + t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) + } +} +func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithLotsOfSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage("T", "f", []string{"A", "B", "C", "D", "E", "F"}) + expected := `Cannot query field "T" on type "f". ` + + `However, this field exists on "A", "B", "C", "D", "E", and 1 other types. ` + + `Perhaps you meant to use an inline fragment?` + if message != expected { + t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) + } +} diff --git a/rules_fragments_on_composite_types_test.go b/rules_fragments_on_composite_types_test.go index 31fbf08b..efe072ab 100644 --- a/rules_fragments_on_composite_types_test.go +++ b/rules_fragments_on_composite_types_test.go @@ -31,6 +31,15 @@ func TestValidate_FragmentsOnCompositeTypes_ObjectIsValidInlineFragmentType(t *t } `) } +func TestValidate_FragmentsOnCompositeTypes_InlineFragmentWithoutTypeIsValid(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.FragmentsOnCompositeTypesRule, ` + fragment validFragment on Pet { + ... { + name + } + } + `) +} func TestValidate_FragmentsOnCompositeTypes_UnionIsValidFragmentType(t *testing.T) { testutil.ExpectPassesRule(t, graphql.FragmentsOnCompositeTypesRule, ` fragment validFragment on CatOrDog { diff --git a/rules_known_directives_rule_test.go b/rules_known_directives_rule_test.go index 0ece1888..1a5e7d5e 100644 --- a/rules_known_directives_rule_test.go +++ b/rules_known_directives_rule_test.go @@ -75,10 +75,12 @@ func TestValidate_KnownDirectives_WithWellPlacedDirectives(t *testing.T) { func TestValidate_KnownDirectives_WithMisplacedDirectives(t *testing.T) { testutil.ExpectFailsRule(t, graphql.KnownDirectivesRule, ` query Foo @include(if: true) { - name - ...Frag + name @operationOnly + ...Frag @operationOnly } `, []gqlerrors.FormattedError{ testutil.RuleError(`Directive "include" may not be used on "operation".`, 2, 17), + testutil.RuleError(`Directive "operationOnly" may not be used on "field".`, 3, 14), + testutil.RuleError(`Directive "operationOnly" may not be used on "fragment".`, 4, 17), }) } diff --git a/rules_known_fragment_names_test.go b/rules_known_fragment_names_test.go index b3d5d52e..eb522b26 100644 --- a/rules_known_fragment_names_test.go +++ b/rules_known_fragment_names_test.go @@ -16,6 +16,9 @@ func TestValidate_KnownFragmentNames_KnownFragmentNamesAreValid(t *testing.T) { ... on Human { ...HumanFields2 } + ... { + name + } } } fragment HumanFields1 on Human { diff --git a/rules_known_type_names_test.go b/rules_known_type_names_test.go index 00c70263..eec9a0ae 100644 --- a/rules_known_type_names_test.go +++ b/rules_known_type_names_test.go @@ -12,7 +12,7 @@ func TestValidate_KnownTypeNames_KnownTypeNamesAreValid(t *testing.T) { testutil.ExpectPassesRule(t, graphql.KnownTypeNamesRule, ` query Foo($var: String, $required: [String!]!) { user(id: 4) { - pets { ... on Pet { name }, ...PetFields } + pets { ... on Pet { name }, ...PetFields, ... { name } } } } fragment PetFields on Pet { @@ -37,3 +37,25 @@ func TestValidate_KnownTypeNames_UnknownTypeNamesAreInValid(t *testing.T) { testutil.RuleError(`Unknown type "Peettt".`, 8, 29), }) } + +func TestValidate_KnownTypeNames_IgnoresTypeDefinitions(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.KnownTypeNamesRule, ` + type NotInTheSchema { + field: FooBar + } + interface FooBar { + field: NotInTheSchema + } + union U = A | B + input Blob { + field: UnknownType + } + query Foo($var: NotInTheSchema) { + user(id: $var) { + id + } + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Unknown type "NotInTheSchema".`, 12, 23), + }) +} diff --git a/rules_lone_anonymous_operation_rule_test.go b/rules_lone_anonymous_operation_rule_test.go index cefaff64..8fb6894f 100644 --- a/rules_lone_anonymous_operation_rule_test.go +++ b/rules_lone_anonymous_operation_rule_test.go @@ -56,7 +56,20 @@ func TestValidate_AnonymousOperationMustBeAlone_MultipleAnonOperations(t *testin testutil.RuleError(`This anonymous operation must be the only defined operation.`, 5, 7), }) } -func TestValidate_AnonymousOperationMustBeAlone_AnonOperationWithAnotherOperation(t *testing.T) { +func TestValidate_AnonymousOperationMustBeAlone_AnonOperationWithAMutation(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.LoneAnonymousOperationRule, ` + { + fieldA + } + mutation Foo { + fieldB + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`This anonymous operation must be the only defined operation.`, 2, 7), + }) +} + +func TestValidate_AnonymousOperationMustBeAlone_AnonOperationWithASubscription(t *testing.T) { testutil.ExpectFailsRule(t, graphql.LoneAnonymousOperationRule, ` { fieldA diff --git a/rules_no_fragment_cycles_test.go b/rules_no_fragment_cycles_test.go index 0eabdb77..f194e305 100644 --- a/rules_no_fragment_cycles_test.go +++ b/rules_no_fragment_cycles_test.go @@ -40,6 +40,13 @@ func TestValidate_NoCircularFragmentSpreads_DoubleSpreadWithinAbstractTypes(t *t } `) } +func TestValidate_NoCircularFragmentSpreads_DoesNotFalsePositiveOnUnknownFragment(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.NoFragmentCyclesRule, ` + fragment nameFragment on Pet { + ...UnknownFragment + } + `) +} func TestValidate_NoCircularFragmentSpreads_SpreadingRecursivelyWithinFieldFails(t *testing.T) { testutil.ExpectFailsRule(t, graphql.NoFragmentCyclesRule, ` fragment fragA on Human { relatives { ...fragA } }, @@ -108,10 +115,21 @@ func TestValidate_NoCircularFragmentSpreads_NoSpreadingItselfDeeply(t *testing.T fragment fragX on Dog { ...fragY } fragment fragY on Dog { ...fragZ } fragment fragZ on Dog { ...fragO } - fragment fragO on Dog { ...fragA, ...fragX } + fragment fragO on Dog { ...fragP } + fragment fragP on Dog { ...fragA, ...fragX } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragB, fragC, fragO.`, 2, 31, 3, 31, 4, 31, 8, 31), - testutil.RuleError(`Cannot spread fragment "fragX" within itself via fragY, fragZ, fragO.`, 5, 31, 6, 31, 7, 31, 8, 41), + testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragB, fragC, fragO, fragP.`, + 2, 31, + 3, 31, + 4, 31, + 8, 31, + 9, 31), + testutil.RuleError(`Cannot spread fragment "fragO" within itself via fragP, fragX, fragY, fragZ.`, + 8, 31, + 9, 41, + 5, 31, + 6, 31, + 7, 31), }) } func TestValidate_NoCircularFragmentSpreads_NoSpreadingItselfDeeplyTwoPaths(t *testing.T) { @@ -120,7 +138,41 @@ func TestValidate_NoCircularFragmentSpreads_NoSpreadingItselfDeeplyTwoPaths(t *t fragment fragB on Dog { ...fragA } fragment fragC on Dog { ...fragA } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragB.`, 2, 31, 3, 31), - testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragC.`, 2, 41, 4, 31), + testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragB.`, + 2, 31, + 3, 31), + testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragC.`, + 2, 41, + 4, 31), + }) +} +func TestValidate_NoCircularFragmentSpreads_NoSpreadingItselfDeeplyTwoPaths_AltTraverseOrder(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.NoFragmentCyclesRule, ` + fragment fragA on Dog { ...fragC } + fragment fragB on Dog { ...fragC } + fragment fragC on Dog { ...fragA, ...fragB } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragC.`, + 2, 31, + 4, 31), + testutil.RuleError(`Cannot spread fragment "fragC" within itself via fragB.`, + 4, 41, + 3, 31), + }) +} +func TestValidate_NoCircularFragmentSpreads_NoSpreadingItselfDeeplyAndImmediately(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.NoFragmentCyclesRule, ` + fragment fragA on Dog { ...fragB } + fragment fragB on Dog { ...fragB, ...fragC } + fragment fragC on Dog { ...fragA, ...fragB } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Cannot spread fragment "fragB" within itself.`, 3, 31), + testutil.RuleError(`Cannot spread fragment "fragA" within itself via fragB, fragC.`, + 2, 31, + 3, 41, + 4, 31), + testutil.RuleError(`Cannot spread fragment "fragB" within itself via fragC.`, + 3, 41, + 4, 41), }) } diff --git a/rules_no_undefined_variables_test.go b/rules_no_undefined_variables_test.go index 64449842..0b253715 100644 --- a/rules_no_undefined_variables_test.go +++ b/rules_no_undefined_variables_test.go @@ -108,7 +108,7 @@ func TestValidate_NoUndefinedVariables_VariableNotDefined(t *testing.T) { field(a: $a, b: $b, c: $c, d: $d) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$d" is not defined.`, 3, 39), + testutil.RuleError(`Variable "$d" is not defined by operation "Foo".`, 3, 39, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableNotDefinedByUnnamedQuery(t *testing.T) { @@ -117,7 +117,7 @@ func TestValidate_NoUndefinedVariables_VariableNotDefinedByUnnamedQuery(t *testi field(a: $a) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 3, 18), + testutil.RuleError(`Variable "$a" is not defined.`, 3, 18, 2, 7), }) } func TestValidate_NoUndefinedVariables_MultipleVariablesNotDefined(t *testing.T) { @@ -126,8 +126,8 @@ func TestValidate_NoUndefinedVariables_MultipleVariablesNotDefined(t *testing.T) field(a: $a, b: $b, c: $c) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 3, 18), - testutil.RuleError(`Variable "$c" is not defined.`, 3, 32), + testutil.RuleError(`Variable "$a" is not defined by operation "Foo".`, 3, 18, 2, 7), + testutil.RuleError(`Variable "$c" is not defined by operation "Foo".`, 3, 32, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByUnnamedQuery(t *testing.T) { @@ -139,7 +139,7 @@ func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByUnnamedQuer field(a: $a) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 6, 18), + testutil.RuleError(`Variable "$a" is not defined.`, 6, 18, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByOperation(t *testing.T) { diff --git a/rules_no_unused_variables_test.go b/rules_no_unused_variables_test.go index d3bcdae4..7c331f4a 100644 --- a/rules_no_unused_variables_test.go +++ b/rules_no_unused_variables_test.go @@ -10,7 +10,7 @@ import ( func TestValidate_NoUnusedVariables_UsesAllVariables(t *testing.T) { testutil.ExpectPassesRule(t, graphql.NoUnusedVariablesRule, ` - query Foo($a: String, $b: String, $c: String) { + query ($a: String, $b: String, $c: String) { field(a: $a, b: $b, c: $c) } `) @@ -91,11 +91,11 @@ func TestValidate_NoUnusedVariables_VariableUsedByRecursiveFragment(t *testing.T } func TestValidate_NoUnusedVariables_VariableNotUsed(t *testing.T) { testutil.ExpectFailsRule(t, graphql.NoUnusedVariablesRule, ` - query Foo($a: String, $b: String, $c: String) { + query ($a: String, $b: String, $c: String) { field(a: $a, b: $b) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$c" is never used.`, 2, 41), + testutil.RuleError(`Variable "$c" is never used.`, 2, 38), }) } func TestValidate_NoUnusedVariables_MultipleVariablesNotUsed(t *testing.T) { @@ -104,8 +104,8 @@ func TestValidate_NoUnusedVariables_MultipleVariablesNotUsed(t *testing.T) { field(b: $b) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is never used.`, 2, 17), - testutil.RuleError(`Variable "$c" is never used.`, 2, 41), + testutil.RuleError(`Variable "$a" is never used in operation "Foo".`, 2, 17), + testutil.RuleError(`Variable "$c" is never used in operation "Foo".`, 2, 41), }) } func TestValidate_NoUnusedVariables_VariableNotUsedInFragments(t *testing.T) { @@ -127,7 +127,7 @@ func TestValidate_NoUnusedVariables_VariableNotUsedInFragments(t *testing.T) { field } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$c" is never used.`, 2, 41), + testutil.RuleError(`Variable "$c" is never used in operation "Foo".`, 2, 41), }) } func TestValidate_NoUnusedVariables_MultipleVariablesNotUsed2(t *testing.T) { @@ -149,8 +149,8 @@ func TestValidate_NoUnusedVariables_MultipleVariablesNotUsed2(t *testing.T) { field } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is never used.`, 2, 17), - testutil.RuleError(`Variable "$c" is never used.`, 2, 41), + testutil.RuleError(`Variable "$a" is never used in operation "Foo".`, 2, 17), + testutil.RuleError(`Variable "$c" is never used in operation "Foo".`, 2, 41), }) } func TestValidate_NoUnusedVariables_VariableNotUsedByUnreferencedFragment(t *testing.T) { @@ -165,7 +165,7 @@ func TestValidate_NoUnusedVariables_VariableNotUsedByUnreferencedFragment(t *tes field(b: $b) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$b" is never used.`, 2, 17), + testutil.RuleError(`Variable "$b" is never used in operation "Foo".`, 2, 17), }) } func TestValidate_NoUnusedVariables_VariableNotUsedByFragmentUsedByOtherOperation(t *testing.T) { @@ -183,7 +183,7 @@ func TestValidate_NoUnusedVariables_VariableNotUsedByFragmentUsedByOtherOperatio field(b: $b) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$b" is never used.`, 2, 17), - testutil.RuleError(`Variable "$a" is never used.`, 5, 17), + testutil.RuleError(`Variable "$b" is never used in operation "Foo".`, 2, 17), + testutil.RuleError(`Variable "$a" is never used in operation "Bar".`, 5, 17), }) } diff --git a/rules_overlapping_fields_can_be_merged_test.go b/rules_overlapping_fields_can_be_merged_test.go index 755c8bbe..903367ea 100644 --- a/rules_overlapping_fields_can_be_merged_test.go +++ b/rules_overlapping_fields_can_be_merged_test.go @@ -56,6 +56,17 @@ func TestValidate_OverlappingFieldsCanBeMerged_DifferentDirectivesWithDifferentA } `) } +func TestValidate_OverlappingFieldsCanBeMerged_DifferentSkipIncludeDirectivesAccepted(t *testing.T) { + // Note: Differing skip/include directives don't create an ambiguous return + // value and are acceptable in conditions where differing runtime values + // may have the same desired effect of including or skipping a field. + testutil.ExpectPassesRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` + fragment differentDirectivesWithDifferentAliases on Dog { + name @include(if: true) + name @include(if: false) + } + `) +} func TestValidate_OverlappingFieldsCanBeMerged_SameAliasesWithDifferentFieldTargets(t *testing.T) { testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` fragment sameAliasesWithDifferentFieldTargets on Dog { @@ -66,7 +77,19 @@ func TestValidate_OverlappingFieldsCanBeMerged_SameAliasesWithDifferentFieldTarg testutil.RuleError(`Fields "fido" conflict because name and nickname are different fields.`, 3, 9, 4, 9), }) } -func TestValidate_OverlappingFieldsCanBeMerged_AliasMakingDirectFieldAccess(t *testing.T) { +func TestValidate_OverlappingFieldsCanBeMerged_SameAliasesAllowedOnNonOverlappingFields(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` + fragment sameAliasesWithDifferentFieldTargets on Pet { + ... on Dog { + name + } + ... on Cat { + name: nickname + } + } + `) +} +func TestValidate_OverlappingFieldsCanBeMerged_AliasMaskingDirectFieldAccess(t *testing.T) { testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` fragment aliasMaskingDirectFieldAccess on Dog { name: nickname @@ -76,55 +99,49 @@ func TestValidate_OverlappingFieldsCanBeMerged_AliasMakingDirectFieldAccess(t *t testutil.RuleError(`Fields "name" conflict because nickname and name are different fields.`, 3, 9, 4, 9), }) } -func TestValidate_OverlappingFieldsCanBeMerged_ConflictingArgs(t *testing.T) { +func TestValidate_OverlappingFieldsCanBeMerged_DifferentArgs_SecondAddsAnArgument(t *testing.T) { testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` fragment conflictingArgs on Dog { - doesKnowCommand(dogCommand: SIT) + doesKnowCommand doesKnowCommand(dogCommand: HEEL) } `, []gqlerrors.FormattedError{ testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments.`, 3, 9, 4, 9), }) } -func TestValidate_OverlappingFieldsCanBeMerged_ConflictingDirectives(t *testing.T) { +func TestValidate_OverlappingFieldsCanBeMerged_DifferentArgs_SecondMissingAnArgument(t *testing.T) { testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` - fragment conflictingDirectiveArgs on Dog { - name @include(if: true) - name @skip(if: false) - } - `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "name" conflict because they have differing directives.`, 3, 9, 4, 9), - }) -} -func TestValidate_OverlappingFieldsCanBeMerged_ConflictingDirectiveArgs(t *testing.T) { - testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` - fragment conflictingDirectiveArgs on Dog { - name @include(if: true) - name @include(if: false) + fragment conflictingArgs on Dog { + doesKnowCommand(dogCommand: SIT) + doesKnowCommand } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "name" conflict because they have differing directives.`, 3, 9, 4, 9), + testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments.`, 3, 9, 4, 9), }) } -func TestValidate_OverlappingFieldsCanBeMerged_ConflictingArgsWithMatchingDirectives(t *testing.T) { +func TestValidate_OverlappingFieldsCanBeMerged_ConflictingArgs(t *testing.T) { testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` - fragment conflictingArgsWithMatchingDirectiveArgs on Dog { - doesKnowCommand(dogCommand: SIT) @include(if: true) - doesKnowCommand(dogCommand: HEEL) @include(if: true) + fragment conflictingArgs on Dog { + doesKnowCommand(dogCommand: SIT) + doesKnowCommand(dogCommand: HEEL) } `, []gqlerrors.FormattedError{ testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments.`, 3, 9, 4, 9), }) } -func TestValidate_OverlappingFieldsCanBeMerged_ConflictingDirectivesWithMatchingArgs(t *testing.T) { - testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` - fragment conflictingDirectiveArgsWithMatchingArgs on Dog { - doesKnowCommand(dogCommand: SIT) @include(if: true) - doesKnowCommand(dogCommand: SIT) @skip(if: false) +func TestValidate_OverlappingFieldsCanBeMerged_AllowDifferentArgsWhereNoConflictIsPossible(t *testing.T) { + // This is valid since no object can be both a "Dog" and a "Cat", thus + // these fields can never overlap. + testutil.ExpectPassesRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` + fragment conflictingArgs on Pet { + ... on Dog { + name(surname: true) + } + ... on Cat { + name + } } - `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing directives.`, 3, 9, 4, 9), - }) + `) } func TestValidate_OverlappingFieldsCanBeMerged_EncountersConflictInFragments(t *testing.T) { testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` @@ -183,7 +200,10 @@ func TestValidate_OverlappingFieldsCanBeMerged_DeepConflict(t *testing.T) { } `, []gqlerrors.FormattedError{ testutil.RuleError(`Fields "field" conflict because subfields "x" conflict because a and b are different fields.`, - 3, 9, 6, 9, 4, 11, 7, 11), + 3, 9, + 4, 11, + 6, 9, + 7, 11), }) } func TestValidate_OverlappingFieldsCanBeMerged_DeepConflictWithMultipleIssues(t *testing.T) { @@ -202,7 +222,12 @@ func TestValidate_OverlappingFieldsCanBeMerged_DeepConflictWithMultipleIssues(t testutil.RuleError( `Fields "field" conflict because subfields "x" conflict because a and b are different fields and `+ `subfields "y" conflict because c and d are different fields.`, - 3, 9, 7, 9, 4, 11, 8, 11, 5, 11, 9, 11), + 3, 9, + 4, 11, + 5, 11, + 7, 9, + 8, 11, + 9, 11), }) } func TestValidate_OverlappingFieldsCanBeMerged_VeryDeepConflict(t *testing.T) { @@ -223,7 +248,12 @@ func TestValidate_OverlappingFieldsCanBeMerged_VeryDeepConflict(t *testing.T) { testutil.RuleError( `Fields "field" conflict because subfields "deepField" conflict because subfields "x" conflict because `+ `a and b are different fields.`, - 3, 9, 8, 9, 4, 11, 9, 11, 5, 13, 10, 13), + 3, 9, + 4, 11, + 5, 13, + 8, 9, + 9, 11, + 10, 13), }) } func TestValidate_OverlappingFieldsCanBeMerged_ReportsDeepConflictToNearestCommonAncestor(t *testing.T) { @@ -247,116 +277,194 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReportsDeepConflictToNearestCommo testutil.RuleError( `Fields "deepField" conflict because subfields "x" conflict because `+ `a and b are different fields.`, - 4, 11, 7, 11, 5, 13, 8, 13), + 4, 11, + 5, 13, + 7, 11, + 8, 13), }) } -var stringBoxObject = graphql.NewObject(graphql.ObjectConfig{ - Name: "StringBox", - Fields: graphql.Fields{ - "scalar": &graphql.Field{ - Type: graphql.String, +var someBoxInterface *graphql.Interface +var stringBoxObject *graphql.Object +var schema graphql.Schema + +func init() { + someBoxInterface = graphql.NewInterface(graphql.InterfaceConfig{ + Name: "SomeBox", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return stringBoxObject }, - }, -}) -var intBoxObject = graphql.NewObject(graphql.ObjectConfig{ - Name: "IntBox", - Fields: graphql.Fields{ - "scalar": &graphql.Field{ - Type: graphql.Int, + Fields: graphql.Fields{ + "unrelatedField": &graphql.Field{ + Type: graphql.String, + }, }, - }, -}) -var nonNullStringBox1Object = graphql.NewObject(graphql.ObjectConfig{ - Name: "NonNullStringBox1", - Fields: graphql.Fields{ - "scalar": &graphql.Field{ - Type: graphql.NewNonNull(graphql.String), + }) + stringBoxObject = graphql.NewObject(graphql.ObjectConfig{ + Name: "StringBox", + Interfaces: (graphql.InterfacesThunk)(func() []*graphql.Interface { + return []*graphql.Interface{someBoxInterface} + }), + Fields: graphql.Fields{ + "scalar": &graphql.Field{ + Type: graphql.String, + }, + "unrelatedField": &graphql.Field{ + Type: graphql.String, + }, }, - }, -}) -var nonNullStringBox2Object = graphql.NewObject(graphql.ObjectConfig{ - Name: "NonNullStringBox2", - Fields: graphql.Fields{ - "scalar": &graphql.Field{ - Type: graphql.NewNonNull(graphql.String), + }) + _ = graphql.NewObject(graphql.ObjectConfig{ + Name: "IntBox", + Interfaces: (graphql.InterfacesThunk)(func() []*graphql.Interface { + return []*graphql.Interface{someBoxInterface} + }), + Fields: graphql.Fields{ + "scalar": &graphql.Field{ + Type: graphql.Int, + }, + "unrelatedField": &graphql.Field{ + Type: graphql.String, + }, }, - }, -}) -var boxUnionObject = graphql.NewUnion(graphql.UnionConfig{ - Name: "BoxUnion", - ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { - return stringBoxObject - }, - Types: []*graphql.Object{ - stringBoxObject, - intBoxObject, - nonNullStringBox1Object, - nonNullStringBox2Object, - }, -}) - -var connectionObject = graphql.NewObject(graphql.ObjectConfig{ - Name: "Connection", - Fields: graphql.Fields{ - "edges": &graphql.Field{ - Type: graphql.NewList(graphql.NewObject(graphql.ObjectConfig{ - Name: "Edge", - Fields: graphql.Fields{ - "node": &graphql.Field{ - Type: graphql.NewObject(graphql.ObjectConfig{ - Name: "Node", - Fields: graphql.Fields{ - "id": &graphql.Field{ - Type: graphql.ID, - }, - "name": &graphql.Field{ - Type: graphql.String, - }, - }, - }), - }, - }, - })), + }) + var nonNullStringBox1Interface = graphql.NewInterface(graphql.InterfaceConfig{ + Name: "NonNullStringBox1", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return stringBoxObject }, - }, -}) -var schema, _ = graphql.NewSchema(graphql.SchemaConfig{ - Query: graphql.NewObject(graphql.ObjectConfig{ - Name: "QueryRoot", Fields: graphql.Fields{ - "boxUnion": &graphql.Field{ - Type: boxUnionObject, + "scalar": &graphql.Field{ + Type: graphql.NewNonNull(graphql.String), + }, + }, + }) + _ = graphql.NewObject(graphql.ObjectConfig{ + Name: "NonNullStringBox1Impl", + Interfaces: (graphql.InterfacesThunk)(func() []*graphql.Interface { + return []*graphql.Interface{someBoxInterface, nonNullStringBox1Interface} + }), + Fields: graphql.Fields{ + "scalar": &graphql.Field{ + Type: graphql.NewNonNull(graphql.String), }, - "connection": &graphql.Field{ - Type: connectionObject, + "unrelatedField": &graphql.Field{ + Type: graphql.String, }, }, - }), -}) + }) + var nonNullStringBox2Interface = graphql.NewInterface(graphql.InterfaceConfig{ + Name: "NonNullStringBox2", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return stringBoxObject + }, + Fields: graphql.Fields{ + "scalar": &graphql.Field{ + Type: graphql.NewNonNull(graphql.String), + }, + }, + }) + _ = graphql.NewObject(graphql.ObjectConfig{ + Name: "NonNullStringBox2Impl", + Interfaces: (graphql.InterfacesThunk)(func() []*graphql.Interface { + return []*graphql.Interface{someBoxInterface, nonNullStringBox2Interface} + }), + Fields: graphql.Fields{ + "scalar": &graphql.Field{ + Type: graphql.NewNonNull(graphql.String), + }, + "unrelatedField": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + + var connectionObject = graphql.NewObject(graphql.ObjectConfig{ + Name: "Connection", + Fields: graphql.Fields{ + "edges": &graphql.Field{ + Type: graphql.NewList(graphql.NewObject(graphql.ObjectConfig{ + Name: "Edge", + Fields: graphql.Fields{ + "node": &graphql.Field{ + Type: graphql.NewObject(graphql.ObjectConfig{ + Name: "Node", + Fields: graphql.Fields{ + "id": &graphql.Field{ + Type: graphql.ID, + }, + "name": &graphql.Field{ + Type: graphql.String, + }, + }, + }), + }, + }, + })), + }, + }, + }) + var err error + schema, err = graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "QueryRoot", + Fields: graphql.Fields{ + "someBox": &graphql.Field{ + Type: someBoxInterface, + }, + "connection": &graphql.Field{ + Type: connectionObject, + }, + }, + }), + }) + if err != nil { + panic(err) + } +} -func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_ConflictingScalarReturnTypes(t *testing.T) { +func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_ConflictingReturnTypesWhichPotentiallyOverlap(t *testing.T) { + // This is invalid since an object could potentially be both the Object + // type IntBox and the interface type NonNullStringBox1. While that + // condition does not exist in the current schema, the schema could + // expand in the future to allow this. Thus it is invalid. testutil.ExpectFailsRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` { - boxUnion { + someBox { ...on IntBox { scalar } - ...on StringBox { + ...on NonNullStringBox1 { scalar } } } `, []gqlerrors.FormattedError{ testutil.RuleError( - `Fields "scalar" conflict because they return differing types Int and String.`, - 5, 15, 8, 15), + `Fields "scalar" conflict because they return differing types Int and String!.`, + 5, 15, + 8, 15), }) } +func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_AllowsDiffereingReturnTypesWhichCannotOverlap(t *testing.T) { + // This is valid since an object cannot be both an IntBox and a StringBox. + testutil.ExpectPassesRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` + { + someBox { + ...on IntBox { + scalar + } + ...on StringBox { + scalar + } + } + } + `) +} func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_SameWrappedScalarReturnTypes(t *testing.T) { testutil.ExpectPassesRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` { - boxUnion { + someBox { ...on NonNullStringBox1 { scalar } @@ -367,6 +475,16 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Same } `) } +func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_AllowsInlineTypelessFragments(t *testing.T) { + testutil.ExpectPassesRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` + { + a + ... { + a + } + } + `) +} func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_ComparesDeepTypesIncludingList(t *testing.T) { testutil.ExpectFailsRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` { @@ -391,13 +509,18 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Comp testutil.RuleError( `Fields "edges" conflict because subfields "node" conflict because subfields "id" conflict because `+ `id and name are different fields.`, - 14, 11, 5, 13, 15, 13, 6, 15, 16, 15, 7, 17), + 14, 11, + 15, 13, + 16, 15, + 5, 13, + 6, 15, + 7, 17), }) } func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_IgnoresUnknownTypes(t *testing.T) { testutil.ExpectPassesRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` { - boxUnion { + someBox { ...on UnknownType { scalar } diff --git a/rules_unique_argument_names_test.go b/rules_unique_argument_names_test.go index 2c111b80..b0e3ec51 100644 --- a/rules_unique_argument_names_test.go +++ b/rules_unique_argument_names_test.go @@ -18,7 +18,7 @@ func TestValidate_UniqueArgumentNames_NoArgumentsOnField(t *testing.T) { func TestValidate_UniqueArgumentNames_NoArgumentsOnDirective(t *testing.T) { testutil.ExpectPassesRule(t, graphql.UniqueArgumentNamesRule, ` { - field + field @directive } `) } diff --git a/rules_unique_input_field_names_test.go b/rules_unique_input_field_names_test.go new file mode 100644 index 00000000..a2e2e251 --- /dev/null +++ b/rules_unique_input_field_names_test.go @@ -0,0 +1,65 @@ +package graphql_test + +import ( + "testing" + + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/gqlerrors" + "github.com/graphql-go/graphql/testutil" +) + +func TestValidate_UniqueInputFieldNames_InputObjectWithFields(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.UniqueInputFieldNamesRule, ` + { + field(arg: { f: true }) + } + `) +} +func TestValidate_UniqueInputFieldNames_SameInputObjectWithinTwoArgs(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.UniqueInputFieldNamesRule, ` + { + field(arg1: { f: true }, arg2: { f: true }) + } + `) +} +func TestValidate_UniqueInputFieldNames_MultipleInputObjectFields(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.UniqueInputFieldNamesRule, ` + { + field(arg: { f1: "value", f2: "value", f3: "value" }) + } + `) +} +func TestValidate_UniqueInputFieldNames_AllowsForNestedInputObjectsWithSimilarFields(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.UniqueInputFieldNamesRule, ` + { + field(arg: { + deep: { + deep: { + id: 1 + } + id: 1 + } + id: 1 + }) + } + `) +} +func TestValidate_UniqueInputFieldNames_DuplicateInputObjectFields(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.UniqueInputFieldNamesRule, ` + { + field(arg: { f1: "value", f1: "value" }) + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`There can be only one input field named "f1".`, 3, 22, 3, 35), + }) +} +func TestValidate_UniqueInputFieldNames_ManyDuplicateInputObjectFields(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.UniqueInputFieldNamesRule, ` + { + field(arg: { f1: "value", f1: "value", f1: "value" }) + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`There can be only one input field named "f1".`, 3, 22, 3, 35), + testutil.RuleError(`There can be only one input field named "f1".`, 3, 22, 3, 48), + }) +} diff --git a/rules_unique_operation_names_test.go b/rules_unique_operation_names_test.go index 7004819e..8903cdcb 100644 --- a/rules_unique_operation_names_test.go +++ b/rules_unique_operation_names_test.go @@ -49,6 +49,10 @@ func TestValidate_UniqueOperationNames_MultipleOperationsOfDifferentTypes(t *tes mutation Bar { field } + + subscription Baz { + field + } `) } func TestValidate_UniqueOperationNames_FragmentAndOperationNamedTheSame(t *testing.T) { @@ -73,7 +77,7 @@ func TestValidate_UniqueOperationNames_MultipleOperationsOfSameName(t *testing.T testutil.RuleError(`There can only be one operation named "Foo".`, 2, 13, 5, 13), }) } -func TestValidate_UniqueOperationNames_MultipleOperationsOfSameNameOfDifferentTypes(t *testing.T) { +func TestValidate_UniqueOperationNames_MultipleOperationsOfSameNameOfDifferentTypes_Mutation(t *testing.T) { testutil.ExpectFailsRule(t, graphql.UniqueOperationNamesRule, ` query Foo { fieldA @@ -85,3 +89,16 @@ func TestValidate_UniqueOperationNames_MultipleOperationsOfSameNameOfDifferentTy testutil.RuleError(`There can only be one operation named "Foo".`, 2, 13, 5, 16), }) } + +func TestValidate_UniqueOperationNames_MultipleOperationsOfSameNameOfDifferentTypes_Subscription(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.UniqueOperationNamesRule, ` + query Foo { + fieldA + } + subscription Foo { + fieldB + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`There can only be one operation named "Foo".`, 2, 13, 5, 20), + }) +} diff --git a/rules_unique_variable_names_test.go b/rules_unique_variable_names_test.go new file mode 100644 index 00000000..63bf7778 --- /dev/null +++ b/rules_unique_variable_names_test.go @@ -0,0 +1,28 @@ +package graphql_test + +import ( + "testing" + + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/gqlerrors" + "github.com/graphql-go/graphql/testutil" +) + +func TestValidate_UniqueVariableNames_UniqueVariableNames(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.UniqueVariableNamesRule, ` + query A($x: Int, $y: String) { __typename } + query B($x: String, $y: Int) { __typename } + `) +} +func TestValidate_UniqueVariableNames_DuplicateVariableNames(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.UniqueVariableNamesRule, ` + query A($x: Int, $x: Int, $x: String) { __typename } + query B($x: String, $x: Int) { __typename } + query C($x: Int, $x: Int) { __typename } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`There can only be one variable named "x".`, 2, 16, 2, 25), + testutil.RuleError(`There can only be one variable named "x".`, 2, 16, 2, 34), + testutil.RuleError(`There can only be one variable named "x".`, 3, 16, 3, 28), + testutil.RuleError(`There can only be one variable named "x".`, 4, 16, 4, 25), + }) +} diff --git a/rules_variables_in_allowed_position_test.go b/rules_variables_in_allowed_position_test.go index 83ee2aa7..78dd77ea 100644 --- a/rules_variables_in_allowed_position_test.go +++ b/rules_variables_in_allowed_position_test.go @@ -121,7 +121,7 @@ func TestValidate_VariablesInAllowedPosition_ComplexInputToComplexInput(t *testi query Query($complexVar: ComplexInput) { complicatedArgs { - complexArgField(complexArg: $ComplexInput) + complexArgField(complexArg: $complexVar) } } `) @@ -154,15 +154,14 @@ func TestValidate_VariablesInAllowedPosition_NonNullableBooleanToNonNullableBool } func TestValidate_VariablesInAllowedPosition_IntToNonNullableInt(t *testing.T) { testutil.ExpectFailsRule(t, graphql.VariablesInAllowedPositionRule, ` - query Query($intArg: Int) - { + query Query($intArg: Int) { complicatedArgs { nonNullIntArgField(nonNullIntArg: $intArg) } } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$intArg" of type "Int" used in position `+ - `expecting type "Int!".`, 5, 45), + `expecting type "Int!".`, 2, 19, 4, 45), }) } func TestValidate_VariablesInAllowedPosition_IntToNonNullableIntWithinFragment(t *testing.T) { @@ -171,15 +170,14 @@ func TestValidate_VariablesInAllowedPosition_IntToNonNullableIntWithinFragment(t nonNullIntArgField(nonNullIntArg: $intArg) } - query Query($intArg: Int) - { + query Query($intArg: Int) { complicatedArgs { ...nonNullIntArgFieldFrag } } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$intArg" of type "Int" used in position `+ - `expecting type "Int!".`, 3, 43), + `expecting type "Int!".`, 6, 19, 3, 43), }) } func TestValidate_VariablesInAllowedPosition_IntToNonNullableIntWithinNestedFragment(t *testing.T) { @@ -192,62 +190,57 @@ func TestValidate_VariablesInAllowedPosition_IntToNonNullableIntWithinNestedFrag nonNullIntArgField(nonNullIntArg: $intArg) } - query Query($intArg: Int) - { + query Query($intArg: Int) { complicatedArgs { ...outerFrag } } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$intArg" of type "Int" used in position `+ - `expecting type "Int!".`, 7, 43), + `expecting type "Int!".`, 10, 19, 7, 43), }) } func TestValidate_VariablesInAllowedPosition_StringOverBoolean(t *testing.T) { testutil.ExpectFailsRule(t, graphql.VariablesInAllowedPositionRule, ` - query Query($stringVar: String) - { + query Query($stringVar: String) { complicatedArgs { booleanArgField(booleanArg: $stringVar) } } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$stringVar" of type "String" used in position `+ - `expecting type "Boolean".`, 5, 39), + `expecting type "Boolean".`, 2, 19, 4, 39), }) } func TestValidate_VariablesInAllowedPosition_StringToListOfString(t *testing.T) { testutil.ExpectFailsRule(t, graphql.VariablesInAllowedPositionRule, ` - query Query($stringVar: String) - { + query Query($stringVar: String) { complicatedArgs { stringListArgField(stringListArg: $stringVar) } } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$stringVar" of type "String" used in position `+ - `expecting type "[String]".`, 5, 45), + `expecting type "[String]".`, 2, 19, 4, 45), }) } func TestValidate_VariablesInAllowedPosition_BooleanToNonNullableBooleanInDirective(t *testing.T) { testutil.ExpectFailsRule(t, graphql.VariablesInAllowedPositionRule, ` - query Query($boolVar: Boolean) - { + query Query($boolVar: Boolean) { dog @include(if: $boolVar) } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$boolVar" of type "Boolean" used in position `+ - `expecting type "Boolean!".`, 4, 26), + `expecting type "Boolean!".`, 2, 19, 3, 26), }) } func TestValidate_VariablesInAllowedPosition_StringToNonNullableBooleanInDirective(t *testing.T) { testutil.ExpectFailsRule(t, graphql.VariablesInAllowedPositionRule, ` - query Query($stringVar: String) - { + query Query($stringVar: String) { dog @include(if: $stringVar) } `, []gqlerrors.FormattedError{ testutil.RuleError(`Variable "$stringVar" of type "String" used in position `+ - `expecting type "Boolean!".`, 4, 26), + `expecting type "Boolean!".`, 2, 19, 3, 26), }) } diff --git a/scalars.go b/scalars.go index 14782369..77842a1f 100644 --- a/scalars.go +++ b/scalars.go @@ -8,6 +8,11 @@ import ( "github.com/graphql-go/graphql/language/ast" ) +// As per the GraphQL Spec, Integers are only treated as valid when a valid +// 32-bit signed integer, providing the broadest support across platforms. +// +// n.b. JavaScript's integers are safe between -(2^53 - 1) and 2^53 - 1 because +// they are internally represented as IEEE 754 doubles. func coerceInt(value interface{}) interface{} { switch value := value.(type) { case bool: @@ -16,6 +21,9 @@ func coerceInt(value interface{}) interface{} { } return 0 case int: + if value < int(math.MinInt32) || value > int(math.MaxInt32) { + return nil + } return value case int8: return int(value) @@ -29,6 +37,9 @@ func coerceInt(value interface{}) interface{} { } return int(value) case uint: + if value > math.MaxInt32 { + return nil + } return int(value) case uint8: return int(value) @@ -68,8 +79,10 @@ func coerceInt(value interface{}) interface{} { } // Int is the GraphQL Integer type definition. -var Int *Scalar = NewScalar(ScalarConfig{ - Name: "Int", +var Int = NewScalar(ScalarConfig{ + Name: "Int", + Description: "The `Int` scalar type represents non-fractional signed whole numeric " + + "values. Int can represent values between -(2^31) and 2^31 - 1. ", Serialize: coerceInt, ParseValue: coerceInt, ParseLiteral: func(valueAST ast.Value) interface{} { @@ -107,8 +120,11 @@ func coerceFloat32(value interface{}) interface{} { } // Float is the GraphQL float type definition. -var Float *Scalar = NewScalar(ScalarConfig{ - Name: "Float", +var Float = NewScalar(ScalarConfig{ + Name: "Float", + Description: "The `Float` scalar type represents signed double-precision fractional " + + "values as specified by " + + "[IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point). ", Serialize: coerceFloat32, ParseValue: coerceFloat32, ParseLiteral: func(valueAST ast.Value) interface{} { @@ -131,8 +147,11 @@ func coerceString(value interface{}) interface{} { } // String is the GraphQL string type definition -var String *Scalar = NewScalar(ScalarConfig{ - Name: "String", +var String = NewScalar(ScalarConfig{ + Name: "String", + Description: "The `String` scalar type represents textual data, represented as UTF-8 " + + "character sequences. The String type is most often used by GraphQL to " + + "represent free-form human-readable text.", Serialize: coerceString, ParseValue: coerceString, ParseLiteral: func(valueAST ast.Value) interface{} { @@ -174,10 +193,11 @@ func coerceBool(value interface{}) interface{} { } // Boolean is the GraphQL boolean type definition -var Boolean *Scalar = NewScalar(ScalarConfig{ - Name: "Boolean", - Serialize: coerceBool, - ParseValue: coerceBool, +var Boolean = NewScalar(ScalarConfig{ + Name: "Boolean", + Description: "The `Boolean` scalar type represents `true` or `false`.", + Serialize: coerceBool, + ParseValue: coerceBool, ParseLiteral: func(valueAST ast.Value) interface{} { switch valueAST := valueAST.(type) { case *ast.BooleanValue: @@ -188,8 +208,13 @@ var Boolean *Scalar = NewScalar(ScalarConfig{ }) // ID is the GraphQL id type definition -var ID *Scalar = NewScalar(ScalarConfig{ - Name: "ID", +var ID = NewScalar(ScalarConfig{ + Name: "ID", + Description: "The `ID` scalar type represents a unique identifier, often used to " + + "refetch an object or as key for a cache. The ID type appears in a JSON " + + "response as a String; however, it is not intended to be human-readable. " + + "When expected as an input type, any string (such as `\"4\"`) or integer " + + "(such as `4`) input value will be accepted as an ID.", Serialize: coerceString, ParseValue: coerceString, ParseLiteral: func(valueAST ast.Value) interface{} { diff --git a/scalars_serialization_test.go b/scalars_serialization_test.go index 96c5ff4d..4dbe2488 100644 --- a/scalars_serialization_test.go +++ b/scalars_serialization_test.go @@ -35,11 +35,13 @@ func TestTypeSystem_Scalar_SerializesOutputInt(t *testing.T) { {float32(0.1), 0}, {float32(1.1), 1}, {float32(-1.1), -1}, - // Bigger than 2^32, but still representable as an Int {float32(1e5), 100000}, {float32(math.MaxFloat32), nil}, - {9876504321, 9876504321}, - {-9876504321, -9876504321}, + // Maybe a safe Go/Javascript `int`, but bigger than 2^32, so not + // representable as a GraphQL Int + {9876504321, nil}, + {-9876504321, nil}, + // Too big to represent as an Int in Go, JavaScript or GraphQL {float64(1e100), nil}, {float64(-1e100), nil}, {"-1.1", -1}, @@ -51,6 +53,9 @@ func TestTypeSystem_Scalar_SerializesOutputInt(t *testing.T) { {int32(1), 1}, {int64(1), 1}, {uint(1), 1}, + // Maybe a safe Go `uint`, but bigger than 2^32, so not + // representable as a GraphQL Int + {uint(math.MaxInt32 + 1), nil}, {uint8(1), 1}, {uint16(1), 1}, {uint32(1), 1}, diff --git a/schema.go b/schema.go index b2be7bba..108cdbac 100644 --- a/schema.go +++ b/schema.go @@ -4,29 +4,32 @@ import ( "fmt" ) -/** -Schema Definition -A Schema is created by supplying the root types of each type of operation, -query and mutation (optional). A schema definition is then supplied to the -validator and executor. -Example: - myAppSchema, err := NewSchema(SchemaConfig({ - Query: MyAppQueryRootType - Mutation: MyAppMutationRootType - }); -*/ type SchemaConfig struct { - Query *Object - Mutation *Object + Query *Object + Mutation *Object + Subscription *Object + Directives []*Directive } -// chose to name as TypeMap instead of TypeMap type TypeMap map[string]Type +//Schema Definition +//A Schema is created by supplying the root types of each type of operation, +//query, mutation (optional) and subscription (optional). A schema definition is then supplied to the +//validator and executor. +//Example: +// myAppSchema, err := NewSchema(SchemaConfig({ +// Query: MyAppQueryRootType, +// Mutation: MyAppMutationRootType, +// Subscription: MyAppSubscriptionRootType, +// }); type Schema struct { - schemaConfig SchemaConfig - typeMap TypeMap - directives []*Directive + typeMap TypeMap + directives []*Directive + + queryType *Object + mutationType *Object + subscriptionType *Object } func NewSchema(config SchemaConfig) (Schema, error) { @@ -47,15 +50,27 @@ func NewSchema(config SchemaConfig) (Schema, error) { return schema, config.Mutation.err } - schema.schemaConfig = config + schema.queryType = config.Query + schema.mutationType = config.Mutation + schema.subscriptionType = config.Subscription + + // Provide `@include() and `@skip()` directives by default. + schema.directives = config.Directives + if len(schema.directives) == 0 { + schema.directives = []*Directive{ + IncludeDirective, + SkipDirective, + } + } // Build type map now to detect any errors within this schema. typeMap := TypeMap{} objectTypes := []*Object{ schema.QueryType(), schema.MutationType(), - __Type, - __Schema, + schema.SubscriptionType(), + typeType, + schemaType, } for _, objectType := range objectTypes { if objectType == nil { @@ -87,20 +102,18 @@ func NewSchema(config SchemaConfig) (Schema, error) { } func (gq *Schema) QueryType() *Object { - return gq.schemaConfig.Query + return gq.queryType } func (gq *Schema) MutationType() *Object { - return gq.schemaConfig.Mutation + return gq.mutationType +} + +func (gq *Schema) SubscriptionType() *Object { + return gq.subscriptionType } func (gq *Schema) Directives() []*Directive { - if len(gq.directives) == 0 { - gq.directives = []*Directive{ - IncludeDirective, - SkipDirective, - } - } return gq.directives } @@ -258,7 +271,7 @@ func assertObjectImplementsInterface(object *Object, iface *Interface) error { ifaceFieldMap := iface.Fields() // Assert each interface field is implemented. - for fieldName, _ := range ifaceFieldMap { + for fieldName := range ifaceFieldMap { objectField := objectFieldMap[fieldName] ifaceField := ifaceFieldMap[fieldName] @@ -272,9 +285,10 @@ func assertObjectImplementsInterface(object *Object, iface *Interface) error { return err } - // Assert interface field type matches object field type. (invariant) + // Assert interface field type is satisfied by object field type, by being + // a valid subtype. (covariant) err = invariant( - isEqualType(ifaceField.Type, objectField.Type), + isTypeSubTypeOf(objectField.Type, ifaceField.Type), fmt.Sprintf(`%v.%v expects type "%v" but `+ `%v.%v provides type "%v".`, iface, fieldName, ifaceField.Type, @@ -321,7 +335,7 @@ func assertObjectImplementsInterface(object *Object, iface *Interface) error { return err } } - // Assert argument set invariance. + // Assert additional arguments must not be required. for _, objectArg := range objectField.Args { argName := objectArg.PrivateName var ifaceArg *Argument @@ -331,15 +345,19 @@ func assertObjectImplementsInterface(object *Object, iface *Interface) error { break } } - err = invariant( - ifaceArg != nil, - fmt.Sprintf(`%v.%v does not define argument "%v" but `+ - `%v.%v provides it.`, - iface, fieldName, argName, - object, fieldName), - ) - if err != nil { - return err + + if ifaceArg == nil { + _, ok := objectArg.Type.(*NonNull) + err = invariant( + !ok, + fmt.Sprintf(`%v.%v(%v:) is of required type `+ + `"%v" but is not also provided by the interface %v.%v.`, + object, fieldName, argName, + objectArg.Type, iface, fieldName), + ) + if err != nil { + return err + } } } } @@ -347,15 +365,72 @@ func assertObjectImplementsInterface(object *Object, iface *Interface) error { } func isEqualType(typeA Type, typeB Type) bool { + // Equivalent type is a valid subtype + if typeA == typeB { + return true + } + // If either type is non-null, the other must also be non-null. if typeA, ok := typeA.(*NonNull); ok { if typeB, ok := typeB.(*NonNull); ok { return isEqualType(typeA.OfType, typeB.OfType) } } + // If either type is a list, the other must also be a list. if typeA, ok := typeA.(*List); ok { if typeB, ok := typeB.(*List); ok { return isEqualType(typeA.OfType, typeB.OfType) } } - return typeA == typeB + // Otherwise the types are not equal. + return false +} + +/** + * Provided a type and a super type, return true if the first type is either + * equal or a subset of the second super type (covariant). + */ +func isTypeSubTypeOf(maybeSubType Type, superType Type) bool { + // Equivalent type is a valid subtype + if maybeSubType == superType { + return true + } + + // If superType is non-null, maybeSubType must also be nullable. + if superType, ok := superType.(*NonNull); ok { + if maybeSubType, ok := maybeSubType.(*NonNull); ok { + return isTypeSubTypeOf(maybeSubType.OfType, superType.OfType) + } + return false + } + if maybeSubType, ok := maybeSubType.(*NonNull); ok { + // If superType is nullable, maybeSubType may be non-null. + return isTypeSubTypeOf(maybeSubType.OfType, superType) + } + + // If superType type is a list, maybeSubType type must also be a list. + if superType, ok := superType.(*List); ok { + if maybeSubType, ok := maybeSubType.(*List); ok { + return isTypeSubTypeOf(maybeSubType.OfType, superType.OfType) + } + return false + } else if _, ok := maybeSubType.(*List); ok { + // If superType is not a list, maybeSubType must also be not a list. + return false + } + + // If superType type is an abstract type, maybeSubType type may be a currently + // possible object type. + if superType, ok := superType.(*Interface); ok { + if maybeSubType, ok := maybeSubType.(*Object); ok && superType.IsPossibleType(maybeSubType) { + return true + } + } + if superType, ok := superType.(*Union); ok { + if maybeSubType, ok := maybeSubType.(*Object); ok && superType.IsPossibleType(maybeSubType) { + return true + } + } + + // Otherwise, the child type is not a valid subtype of the parent type. + return false } diff --git a/testutil/introspection_query.go b/testutil/introspection_query.go index 9d336353..555ad9df 100644 --- a/testutil/introspection_query.go +++ b/testutil/introspection_query.go @@ -5,6 +5,7 @@ var IntrospectionQuery = ` __schema { queryType { name } mutationType { name } + subscriptionType { name } types { ...FullType } @@ -25,7 +26,7 @@ var IntrospectionQuery = ` kind name description - fields { + fields(includeDeprecated: true) { name description args { @@ -43,7 +44,7 @@ var IntrospectionQuery = ` interfaces { ...TypeRef } - enumValues { + enumValues(includeDeprecated: true) { name description isDeprecated diff --git a/testutil/rules_test_harness.go b/testutil/rules_test_harness.go index 3acb9094..3691f6c4 100644 --- a/testutil/rules_test_harness.go +++ b/testutil/rules_test_harness.go @@ -11,7 +11,7 @@ import ( "reflect" ) -var defaultRulesTestSchema *graphql.Schema +var TestSchema *graphql.Schema func init() { @@ -41,6 +41,19 @@ func init() { }, }, }) + var canineInterface = graphql.NewInterface(graphql.InterfaceConfig{ + Name: "Canine", + Fields: graphql.Fields{ + "name": &graphql.Field{ + Type: graphql.String, + Args: graphql.FieldConfigArgument{ + "surname": &graphql.ArgumentConfig{ + Type: graphql.Boolean, + }, + }, + }, + }, + }) var dogCommandEnum = graphql.NewEnum(graphql.EnumConfig{ Name: "DogCommand", Values: graphql.EnumValueConfigMap{ @@ -110,6 +123,7 @@ func init() { Interfaces: []*graphql.Interface{ beingInterface, petInterface, + canineInterface, }, }) var furColorEnum = graphql.NewEnum(graphql.EnumConfig{ @@ -444,11 +458,19 @@ func init() { }) schema, err := graphql.NewSchema(graphql.SchemaConfig{ Query: queryRoot, + Directives: []*graphql.Directive{ + graphql.NewDirective(&graphql.Directive{ + Name: "operationOnly", + OnOperation: true, + }), + graphql.IncludeDirective, + graphql.SkipDirective, + }, }) if err != nil { panic(err) } - defaultRulesTestSchema = &schema + TestSchema = &schema } func expectValidRule(t *testing.T, schema *graphql.Schema, rules []graphql.ValidationRuleFn, queryString string) { @@ -498,10 +520,10 @@ func expectInvalidRule(t *testing.T, schema *graphql.Schema, rules []graphql.Val } func ExpectPassesRule(t *testing.T, rule graphql.ValidationRuleFn, queryString string) { - expectValidRule(t, defaultRulesTestSchema, []graphql.ValidationRuleFn{rule}, queryString) + expectValidRule(t, TestSchema, []graphql.ValidationRuleFn{rule}, queryString) } func ExpectFailsRule(t *testing.T, rule graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { - expectInvalidRule(t, defaultRulesTestSchema, []graphql.ValidationRuleFn{rule}, queryString, expectedErrors) + expectInvalidRule(t, TestSchema, []graphql.ValidationRuleFn{rule}, queryString, expectedErrors) } func ExpectFailsRuleWithSchema(t *testing.T, schema *graphql.Schema, rule graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { expectInvalidRule(t, schema, []graphql.ValidationRuleFn{rule}, queryString, expectedErrors) diff --git a/testutil/testutil.go b/testutil/testutil.go index 9077a154..9951d8f3 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -29,7 +29,7 @@ var ( ) type StarWarsChar struct { - Id string + ID string Name string Friends []StarWarsChar AppearsIn []int @@ -39,41 +39,41 @@ type StarWarsChar struct { func init() { Luke = StarWarsChar{ - Id: "1000", + ID: "1000", Name: "Luke Skywalker", AppearsIn: []int{4, 5, 6}, HomePlanet: "Tatooine", } Vader = StarWarsChar{ - Id: "1001", + ID: "1001", Name: "Darth Vader", AppearsIn: []int{4, 5, 6}, HomePlanet: "Tatooine", } Han = StarWarsChar{ - Id: "1002", + ID: "1002", Name: "Han Solo", AppearsIn: []int{4, 5, 6}, } Leia = StarWarsChar{ - Id: "1003", + ID: "1003", Name: "Leia Organa", AppearsIn: []int{4, 5, 6}, HomePlanet: "Alderaa", } Tarkin = StarWarsChar{ - Id: "1004", + ID: "1004", Name: "Wilhuff Tarkin", AppearsIn: []int{4}, } Threepio = StarWarsChar{ - Id: "2000", + ID: "2000", Name: "C-3PO", AppearsIn: []int{4, 5, 6}, PrimaryFunction: "Protocol", } Artoo = StarWarsChar{ - Id: "2001", + ID: "2001", Name: "R2-D2", AppearsIn: []int{4, 5, 6}, PrimaryFunction: "Astromech", @@ -135,9 +135,9 @@ func init() { }, ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { if character, ok := value.(StarWarsChar); ok { - id, _ := strconv.Atoi(character.Id) + id, _ := strconv.Atoi(character.ID) human := GetHuman(id) - if human.Id != "" { + if human.ID != "" { return humanType } } @@ -158,7 +158,7 @@ func init() { Description: "The id of the human.", Resolve: func(p graphql.ResolveParams) (interface{}, error) { if human, ok := p.Source.(StarWarsChar); ok { - return human.Id, nil + return human.ID, nil } return nil, nil }, @@ -217,7 +217,7 @@ func init() { Description: "The id of the droid.", Resolve: func(p graphql.ResolveParams) (interface{}, error) { if droid, ok := p.Source.(StarWarsChar); ok { - return droid.Id, nil + return droid.ID, nil } return nil, nil }, @@ -241,7 +241,7 @@ func init() { for _, friend := range droid.Friends { friends = append(friends, map[string]interface{}{ "name": friend.Name, - "id": friend.Id, + "id": friend.ID, }) } return droid.Friends, nil @@ -418,9 +418,8 @@ subLoop: } if !found { return false - } else { - continue subLoop } + continue subLoop } return true } diff --git a/type_info.go b/type_info.go index e7978889..3c06c29b 100644 --- a/type_info.go +++ b/type_info.go @@ -11,6 +11,8 @@ import ( * of the current field and type definitions at any point in a GraphQL document * AST during a recursive descent by calling `enter(node)` and `leave(node)`. */ +type fieldDefFn func(schema *Schema, parentType Type, fieldAST *ast.Field) *FieldDefinition + type TypeInfo struct { schema *Schema typeStack []Output @@ -19,11 +21,26 @@ type TypeInfo struct { fieldDefStack []*FieldDefinition directive *Directive argument *Argument + getFieldDef fieldDefFn } -func NewTypeInfo(schema *Schema) *TypeInfo { +type TypeInfoConfig struct { + Schema *Schema + + // NOTE: this experimental optional second parameter is only needed in order + // to support non-spec-compliant codebases. You should never need to use it. + // It may disappear in the future. + FieldDefFn fieldDefFn +} + +func NewTypeInfo(opts *TypeInfoConfig) *TypeInfo { + getFieldDef := opts.FieldDefFn + if getFieldDef == nil { + getFieldDef = DefaultTypeInfoFieldDef + } return &TypeInfo{ - schema: schema, + schema: opts.Schema, + getFieldDef: getFieldDef, } } @@ -69,7 +86,7 @@ func (ti *TypeInfo) Enter(node ast.Node) { switch node := node.(type) { case *ast.SelectionSet: namedType := GetNamed(ti.Type()) - var compositeType Composite = nil + var compositeType Composite if IsCompositeType(namedType) { compositeType, _ = namedType.(Composite) } @@ -78,7 +95,7 @@ func (ti *TypeInfo) Enter(node ast.Node) { parentType := ti.ParentType() var fieldDef *FieldDefinition if parentType != nil { - fieldDef = TypeInfoFieldDef(*schema, parentType.(Type), node) + fieldDef = ti.getFieldDef(schema, parentType.(Type), node) } ti.fieldDefStack = append(ti.fieldDefStack, fieldDef) if fieldDef != nil { @@ -97,14 +114,26 @@ func (ti *TypeInfo) Enter(node ast.Node) { ttype = schema.QueryType() } else if node.Operation == "mutation" { ttype = schema.MutationType() + } else if node.Operation == "subscription" { + ttype = schema.SubscriptionType() } ti.typeStack = append(ti.typeStack, ttype) case *ast.InlineFragment: - ttype, _ = typeFromAST(*schema, node.TypeCondition) - ti.typeStack = append(ti.typeStack, ttype) + typeConditionAST := node.TypeCondition + if typeConditionAST != nil { + ttype, _ = typeFromAST(*schema, node.TypeCondition) + ti.typeStack = append(ti.typeStack, ttype) + } else { + ti.typeStack = append(ti.typeStack, ti.Type()) + } case *ast.FragmentDefinition: - ttype, _ = typeFromAST(*schema, node.TypeCondition) - ti.typeStack = append(ti.typeStack, ttype) + typeConditionAST := node.TypeCondition + if typeConditionAST != nil { + ttype, _ = typeFromAST(*schema, typeConditionAST) + ti.typeStack = append(ti.typeStack, ttype) + } else { + ti.typeStack = append(ti.typeStack, ti.Type()) + } case *ast.VariableDefinition: ttype, _ = typeFromAST(*schema, node.Type) ti.inputTypeStack = append(ti.inputTypeStack, ttype) @@ -163,12 +192,18 @@ func (ti *TypeInfo) Leave(node ast.Node) { switch kind { case kinds.SelectionSet: // pop ti.parentTypeStack - _, ti.parentTypeStack = ti.parentTypeStack[len(ti.parentTypeStack)-1], ti.parentTypeStack[:len(ti.parentTypeStack)-1] + if len(ti.parentTypeStack) > 0 { + _, ti.parentTypeStack = ti.parentTypeStack[len(ti.parentTypeStack)-1], ti.parentTypeStack[:len(ti.parentTypeStack)-1] + } case kinds.Field: // pop ti.fieldDefStack - _, ti.fieldDefStack = ti.fieldDefStack[len(ti.fieldDefStack)-1], ti.fieldDefStack[:len(ti.fieldDefStack)-1] + if len(ti.fieldDefStack) > 0 { + _, ti.fieldDefStack = ti.fieldDefStack[len(ti.fieldDefStack)-1], ti.fieldDefStack[:len(ti.fieldDefStack)-1] + } // pop ti.typeStack - _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + if len(ti.typeStack) > 0 { + _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + } case kinds.Directive: ti.directive = nil case kinds.OperationDefinition: @@ -177,28 +212,34 @@ func (ti *TypeInfo) Leave(node ast.Node) { fallthrough case kinds.FragmentDefinition: // pop ti.typeStack - _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + if len(ti.typeStack) > 0 { + _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + } case kinds.VariableDefinition: // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } case kinds.Argument: ti.argument = nil // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } case kinds.ListValue: fallthrough case kinds.ObjectField: // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } } } -/** - * Not exactly the same as the executor's definition of FieldDef, in this - * statically evaluated environment we do not always have an Object type, - * and need to handle Interface and Union types. - */ -func TypeInfoFieldDef(schema Schema, parentType Type, fieldAST *ast.Field) *FieldDefinition { +// DefaultTypeInfoFieldDef Not exactly the same as the executor's definition of FieldDef, in this +// statically evaluated environment we do not always have an Object type, +// and need to handle Interface and Union types. +func DefaultTypeInfoFieldDef(schema *Schema, parentType Type, fieldAST *ast.Field) *FieldDefinition { name := "" if fieldAST.Name != nil { name = fieldAST.Name.Value diff --git a/union_interface_test.go b/union_interface_test.go index ce8ac8c2..2da90f40 100644 --- a/union_interface_test.go +++ b/union_interface_test.go @@ -497,9 +497,6 @@ func TestUnionIntersectionTypes_AllowsFragmentConditionsToBeAbstractTypes(t *tes } func TestUnionIntersectionTypes_GetsExecutionInfoInResolver(t *testing.T) { - //var encounteredSchema *graphql.Schema - //var encounteredRootValue interface{} - var personType2 *graphql.Object namedType2 := graphql.NewInterface(graphql.InterfaceConfig{ @@ -510,8 +507,6 @@ func TestUnionIntersectionTypes_GetsExecutionInfoInResolver(t *testing.T) { }, }, ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { - //encounteredSchema = &info.Schema - //encounteredRootValue = info.RootValue return personType2 }, }) diff --git a/validation_test.go b/validation_test.go index 42fde9cc..85d2c98d 100644 --- a/validation_test.go +++ b/validation_test.go @@ -294,6 +294,23 @@ func TestTypeSystem_SchemaMustHaveObjectRootTypes_AcceptsASchemaWhoseQueryAndMut t.Fatalf("unexpected error: %v", err) } } +func TestTypeSystem_SchemaMustHaveObjectRootTypes_AcceptsASchemaWhoseQueryAndSubscriptionTypesAreObjectType(t *testing.T) { + subscriptionType := graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "subscribe": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + _, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: someObjectType, + Mutation: subscriptionType, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} func TestTypeSystem_SchemaMustHaveObjectRootTypes_RejectsASchemaWithoutAQueryType(t *testing.T) { _, err := graphql.NewSchema(graphql.SchemaConfig{}) expectedError := "Schema query must be Object Type but got: nil." @@ -1175,7 +1192,7 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_AcceptsAnObjectWhi t.Fatalf(`unexpected error: %v for type "%v"`, err, anotherObject) } } -func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWhichImplementsAnInterfaceFieldAlongWithMoreArguments(t *testing.T) { +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_AcceptsAnObjectWhichImpementsAnInterfaceFieldAlongWithAdditionalOptionalArguments(t *testing.T) { anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ Name: "AnotherInterface", ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { @@ -1210,7 +1227,46 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWhi }, }) _, err := schemaWithObjectFieldOfType(anotherObject) - expectedError := `AnotherInterface.field does not define argument "anotherInput" but AnotherObject.field provides it.` + if err != nil { + t.Fatalf(`unexpected error: %v for type "%v"`, err, anotherObject) + } +} +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWhichImplementsAnInterfaceFieldAlongWithAdditionalRequiredArguments(t *testing.T) { + anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.String, + Args: graphql.FieldConfigArgument{ + "input": &graphql.ArgumentConfig{ + Type: graphql.String, + }, + }, + }, + }, + }) + anotherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.String, + Args: graphql.FieldConfigArgument{ + "input": &graphql.ArgumentConfig{ + Type: graphql.String, + }, + "anotherInput": &graphql.ArgumentConfig{ + Type: graphql.NewNonNull(graphql.String), + }, + }, + }, + }, + }) + _, err := schemaWithObjectFieldOfType(anotherObject) + expectedError := `AnotherObject.field(anotherInput:) is of required type "String!" but is not also provided by the interface AnotherInterface.field.` if err == nil || err.Error() != expectedError { t.Fatalf("Expected error: %v, got %v", expectedError, err) } @@ -1247,6 +1303,7 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectMis t.Fatalf("Expected error: %v, got %v", expectedError, err) } } + func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWithAnIncorrectlyTypedInterfaceField(t *testing.T) { anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ Name: "AnotherInterface", @@ -1256,11 +1313,6 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWit Fields: graphql.Fields{ "field": &graphql.Field{ Type: graphql.String, - Args: graphql.FieldConfigArgument{ - "input": &graphql.ArgumentConfig{ - Type: graphql.String, - }, - }, }, }, }) @@ -1270,11 +1322,6 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWit Fields: graphql.Fields{ "field": &graphql.Field{ Type: someScalarType, - Args: graphql.FieldConfigArgument{ - "input": &graphql.ArgumentConfig{ - Type: graphql.String, - }, - }, }, }, }) @@ -1284,6 +1331,111 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWit t.Fatalf("Expected error: %v, got %v", expectedError, err) } } + +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWithADifferentlyTypeInterfaceField(t *testing.T) { + + typeA := graphql.NewObject(graphql.ObjectConfig{ + Name: "A", + Fields: graphql.Fields{ + "foo": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + typeB := graphql.NewObject(graphql.ObjectConfig{ + Name: "B", + Fields: graphql.Fields{ + "foo": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + + anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: typeA, + }, + }, + }) + anotherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: typeB, + }, + }, + }) + _, err := schemaWithObjectFieldOfType(anotherObject) + expectedError := `AnotherInterface.field expects type "A" but AnotherObject.field provides type "B".` + if err == nil || err.Error() != expectedError { + t.Fatalf("Expected error: %v, got %v", expectedError, err) + } +} + +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_AcceptsAnObjectWithASubtypedInterfaceField_Interface(t *testing.T) { + var anotherInterface *graphql.Interface + anotherInterface = graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: (graphql.FieldsThunk)(func() graphql.Fields { + return graphql.Fields{ + "field": &graphql.Field{ + Type: anotherInterface, + }, + } + }), + }) + var anotherObject *graphql.Object + anotherObject = graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: (graphql.FieldsThunk)(func() graphql.Fields { + return graphql.Fields{ + "field": &graphql.Field{ + Type: anotherObject, + }, + } + }), + }) + _, err := schemaWithFieldType(anotherObject) + if err != nil { + t.Fatalf(`unexpected error: %v for type "%v"`, err, anotherObject) + } +} +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_AcceptsAnObjectWithASubtypedInterfaceField_Union(t *testing.T) { + anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: someUnionType, + }, + }, + }) + anotherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: someObjectType, + }, + }, + }) + _, err := schemaWithFieldType(anotherObject) + if err != nil { + t.Fatalf(`unexpected error: %v for type "%v"`, err, anotherObject) + } +} func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectMissingAnInterfaceArgument(t *testing.T) { anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ Name: "AnotherInterface", @@ -1379,7 +1531,63 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_AcceptsAnObjectWit t.Fatalf(`unexpected error: %v for type "%v"`, err, anotherObject) } } -func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWithADifferentlyModifiedInterfaceFieldType(t *testing.T) { +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWithANonListInterfaceFieldListType(t *testing.T) { + anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.NewList(graphql.String), + }, + }, + }) + anotherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + _, err := schemaWithFieldType(anotherObject) + expectedError := `AnotherInterface.field expects type "[String]" but AnotherObject.field provides type "String".` + if err == nil || err.Error() != expectedError { + t.Fatalf("Expected error: %v, got %v", expectedError, err) + } +} + +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWithAListInterfaceFieldNonListType(t *testing.T) { + anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + anotherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.NewList(graphql.String), + }, + }, + }) + _, err := schemaWithFieldType(anotherObject) + expectedError := `AnotherInterface.field expects type "String" but AnotherObject.field provides type "[String]".` + if err == nil || err.Error() != expectedError { + t.Fatalf("Expected error: %v, got %v", expectedError, err) + } +} + +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_AcceptsAnObjectWithSubsetNonNullInterfaceFieldType(t *testing.T) { anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ Name: "AnotherInterface", ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { @@ -1400,8 +1608,35 @@ func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWit }, }, }) - _, err := schemaWithObjectFieldOfType(anotherObject) - expectedError := `AnotherInterface.field expects type "String" but AnotherObject.field provides type "String!".` + _, err := schemaWithFieldType(anotherObject) + if err != nil { + t.Fatalf(`unexpected error: %v for type "%v"`, err, anotherObject) + } +} + +func TestTypeSystem_ObjectsMustAdhereToInterfaceTheyImplement_RejectsAnObjectWithASupersetNullableInterfaceFieldType(t *testing.T) { + anotherInterface := graphql.NewInterface(graphql.InterfaceConfig{ + Name: "AnotherInterface", + ResolveType: func(value interface{}, info graphql.ResolveInfo) *graphql.Object { + return nil + }, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.NewNonNull(graphql.String), + }, + }, + }) + anotherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "AnotherObject", + Interfaces: []*graphql.Interface{anotherInterface}, + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + _, err := schemaWithFieldType(anotherObject) + expectedError := `AnotherInterface.field expects type "String!" but AnotherObject.field provides type "String".` if err == nil || err.Error() != expectedError { t.Fatalf("Expected error: %v, got %v", expectedError, err) } diff --git a/validator.go b/validator.go index 2873fd64..1e057935 100644 --- a/validator.go +++ b/validator.go @@ -12,6 +12,20 @@ type ValidationResult struct { Errors []gqlerrors.FormattedError } +/** + * Implements the "Validation" section of the spec. + * + * Validation runs synchronously, returning an array of encountered errors, or + * an empty array if no errors were encountered and the document is valid. + * + * A list of specific validation rules may be provided. If not provided, the + * default list of rules defined by the GraphQL specification will be used. + * + * Each validation rules is a function which returns a visitor + * (see the language/visitor API). Visitor methods are expected to return + * GraphQLErrors, or Arrays of GraphQLErrors when invalid. + */ + func ValidateDocument(schema *Schema, astDoc *ast.Document, rules []ValidationRuleFn) (vr ValidationResult) { if len(rules) == 0 { rules = SpecifiedRules @@ -26,150 +40,91 @@ func ValidateDocument(schema *Schema, astDoc *ast.Document, rules []ValidationRu vr.Errors = append(vr.Errors, gqlerrors.NewFormattedError("Must provide document")) return vr } - vr.Errors = visitUsingRules(schema, astDoc, rules) + + typeInfo := NewTypeInfo(&TypeInfoConfig{ + Schema: schema, + }) + vr.Errors = VisitUsingRules(schema, typeInfo, astDoc, rules) if len(vr.Errors) == 0 { vr.IsValid = true } return vr } -func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRuleFn) (errors []gqlerrors.FormattedError) { - typeInfo := NewTypeInfo(schema) - context := NewValidationContext(schema, astDoc, typeInfo) - - var visitInstance func(astNode ast.Node, instance *ValidationRuleInstance) - - visitInstance = func(astNode ast.Node, instance *ValidationRuleInstance) { - visitor.Visit(astNode, &visitor.VisitorOptions{ - Enter: func(p visitor.VisitFuncParams) (string, interface{}) { - var action = visitor.ActionNoChange - var result interface{} - switch node := p.Node.(type) { - case ast.Node: - // Collect type information about the current position in the AST. - typeInfo.Enter(node) - - // Do not visit top level fragment definitions if this instance will - // visit those fragments inline because it - // provided `visitSpreadFragments`. - kind := node.GetKind() - - if kind == kinds.FragmentDefinition && - p.Key != nil && instance.VisitSpreadFragments == true { - return visitor.ActionSkip, nil - } - - // Get the visitor function from the validation instance, and if it - // exists, call it with the visitor arguments. - enterFn := visitor.GetVisitFn(instance.VisitorOpts, false, kind) - if enterFn != nil { - action, result = enterFn(p) - } - - // If the visitor returned an error, log it and do not visit any - // deeper nodes. - if err, ok := result.(error); ok && err != nil { - errors = append(errors, gqlerrors.FormatError(err)) - action = visitor.ActionSkip - } - if err, ok := result.([]error); ok && err != nil { - errors = append(errors, gqlerrors.FormatErrors(err...)...) - action = visitor.ActionSkip - } - - // If any validation instances provide the flag `visitSpreadFragments` - // and this node is a fragment spread, visit the fragment definition - // from this point. - if action == visitor.ActionNoChange && result == nil && - instance.VisitSpreadFragments == true && kind == kinds.FragmentSpread { - node, _ := node.(*ast.FragmentSpread) - name := node.Name - nameVal := "" - if name != nil { - nameVal = name.Value - } - fragment := context.Fragment(nameVal) - if fragment != nil { - visitInstance(fragment, instance) - } - } +// VisitUsingRules This uses a specialized visitor which runs multiple visitors in parallel, +// while maintaining the visitor skip and break API. +// +// @internal +// Had to expose it to unit test experimental customizable validation feature, +// but not meant for public consumption +func VisitUsingRules(schema *Schema, typeInfo *TypeInfo, astDoc *ast.Document, rules []ValidationRuleFn) []gqlerrors.FormattedError { - // If the result is "false" (ie action === Action.Skip), we're not visiting any descendent nodes, - // but need to update typeInfo. - if action == visitor.ActionSkip { - typeInfo.Leave(node) - } + context := NewValidationContext(schema, astDoc, typeInfo) + visitors := []*visitor.VisitorOptions{} - } + for _, rule := range rules { + instance := rule(context) + visitors = append(visitors, instance.VisitorOpts) + } - return action, result - }, - Leave: func(p visitor.VisitFuncParams) (string, interface{}) { - var action = visitor.ActionNoChange - var result interface{} - switch node := p.Node.(type) { - case ast.Node: - kind := node.GetKind() - - // Get the visitor function from the validation instance, and if it - // exists, call it with the visitor arguments. - leaveFn := visitor.GetVisitFn(instance.VisitorOpts, true, kind) - if leaveFn != nil { - action, result = leaveFn(p) - } + // Visit the whole document with each instance of all provided rules. + visitor.Visit(astDoc, visitor.VisitWithTypeInfo(typeInfo, visitor.VisitInParallel(visitors...)), nil) + return context.Errors() +} - // If the visitor returned an error, log it and do not visit any - // deeper nodes. - if err, ok := result.(error); ok && err != nil { - errors = append(errors, gqlerrors.FormatError(err)) - action = visitor.ActionSkip - } - if err, ok := result.([]error); ok && err != nil { - errors = append(errors, gqlerrors.FormatErrors(err...)...) - action = visitor.ActionSkip - } +type HasSelectionSet interface { + GetKind() string + GetLoc() *ast.Location + GetSelectionSet() *ast.SelectionSet +} - // Update typeInfo. - typeInfo.Leave(node) - } - return action, result - }, - }, nil) - } +var _ HasSelectionSet = (*ast.OperationDefinition)(nil) +var _ HasSelectionSet = (*ast.FragmentDefinition)(nil) - instances := []*ValidationRuleInstance{} - for _, rule := range rules { - instance := rule(context) - instances = append(instances, instance) - } - for _, instance := range instances { - visitInstance(astDoc, instance) - } - return errors +type VariableUsage struct { + Node *ast.Variable + Type Input } type ValidationContext struct { - schema *Schema - astDoc *ast.Document - typeInfo *TypeInfo - fragments map[string]*ast.FragmentDefinition + schema *Schema + astDoc *ast.Document + typeInfo *TypeInfo + errors []gqlerrors.FormattedError + fragments map[string]*ast.FragmentDefinition + variableUsages map[HasSelectionSet][]*VariableUsage + recursiveVariableUsages map[*ast.OperationDefinition][]*VariableUsage + recursivelyReferencedFragments map[*ast.OperationDefinition][]*ast.FragmentDefinition + fragmentSpreads map[HasSelectionSet][]*ast.FragmentSpread } func NewValidationContext(schema *Schema, astDoc *ast.Document, typeInfo *TypeInfo) *ValidationContext { return &ValidationContext{ - schema: schema, - astDoc: astDoc, - typeInfo: typeInfo, + schema: schema, + astDoc: astDoc, + typeInfo: typeInfo, + fragments: map[string]*ast.FragmentDefinition{}, + variableUsages: map[HasSelectionSet][]*VariableUsage{}, + recursiveVariableUsages: map[*ast.OperationDefinition][]*VariableUsage{}, + recursivelyReferencedFragments: map[*ast.OperationDefinition][]*ast.FragmentDefinition{}, + fragmentSpreads: map[HasSelectionSet][]*ast.FragmentSpread{}, } } +func (ctx *ValidationContext) ReportError(err error) { + formattedErr := gqlerrors.FormatError(err) + ctx.errors = append(ctx.errors, formattedErr) +} +func (ctx *ValidationContext) Errors() []gqlerrors.FormattedError { + return ctx.errors +} + func (ctx *ValidationContext) Schema() *Schema { return ctx.schema } func (ctx *ValidationContext) Document() *ast.Document { return ctx.astDoc } - func (ctx *ValidationContext) Fragment(name string) *ast.FragmentDefinition { if len(ctx.fragments) == 0 { if ctx.Document() == nil { @@ -191,7 +146,128 @@ func (ctx *ValidationContext) Fragment(name string) *ast.FragmentDefinition { f, _ := ctx.fragments[name] return f } +func (ctx *ValidationContext) FragmentSpreads(node HasSelectionSet) []*ast.FragmentSpread { + if spreads, ok := ctx.fragmentSpreads[node]; ok && spreads != nil { + return spreads + } + + spreads := []*ast.FragmentSpread{} + setsToVisit := []*ast.SelectionSet{node.GetSelectionSet()} + + for { + if len(setsToVisit) == 0 { + break + } + var set *ast.SelectionSet + // pop + set, setsToVisit = setsToVisit[len(setsToVisit)-1], setsToVisit[:len(setsToVisit)-1] + if set.Selections != nil { + for _, selection := range set.Selections { + switch selection := selection.(type) { + case *ast.FragmentSpread: + spreads = append(spreads, selection) + case *ast.Field: + if selection.SelectionSet != nil { + setsToVisit = append(setsToVisit, selection.SelectionSet) + } + case *ast.InlineFragment: + if selection.SelectionSet != nil { + setsToVisit = append(setsToVisit, selection.SelectionSet) + } + } + } + } + ctx.fragmentSpreads[node] = spreads + } + return spreads +} + +func (ctx *ValidationContext) RecursivelyReferencedFragments(operation *ast.OperationDefinition) []*ast.FragmentDefinition { + if fragments, ok := ctx.recursivelyReferencedFragments[operation]; ok && fragments != nil { + return fragments + } + + fragments := []*ast.FragmentDefinition{} + collectedNames := map[string]bool{} + nodesToVisit := []HasSelectionSet{operation} + + for { + if len(nodesToVisit) == 0 { + break + } + + var node HasSelectionSet + + node, nodesToVisit = nodesToVisit[len(nodesToVisit)-1], nodesToVisit[:len(nodesToVisit)-1] + spreads := ctx.FragmentSpreads(node) + for _, spread := range spreads { + fragName := "" + if spread.Name != nil { + fragName = spread.Name.Value + } + if res, ok := collectedNames[fragName]; !ok || !res { + collectedNames[fragName] = true + fragment := ctx.Fragment(fragName) + if fragment != nil { + fragments = append(fragments, fragment) + nodesToVisit = append(nodesToVisit, fragment) + } + } + + } + } + + ctx.recursivelyReferencedFragments[operation] = fragments + return fragments +} +func (ctx *ValidationContext) VariableUsages(node HasSelectionSet) []*VariableUsage { + if usages, ok := ctx.variableUsages[node]; ok && usages != nil { + return usages + } + usages := []*VariableUsage{} + typeInfo := NewTypeInfo(&TypeInfoConfig{ + Schema: ctx.schema, + }) + + visitor.Visit(node, visitor.VisitWithTypeInfo(typeInfo, &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.VariableDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.Variable: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.Variable); ok && node != nil { + usages = append(usages, &VariableUsage{ + Node: node, + Type: typeInfo.InputType(), + }) + } + return visitor.ActionNoChange, nil + }, + }, + }, + }), nil) + + ctx.variableUsages[node] = usages + return usages +} +func (ctx *ValidationContext) RecursiveVariableUsages(operation *ast.OperationDefinition) []*VariableUsage { + if usages, ok := ctx.recursiveVariableUsages[operation]; ok && usages != nil { + return usages + } + usages := ctx.VariableUsages(operation) + + fragments := ctx.RecursivelyReferencedFragments(operation) + for _, fragment := range fragments { + fragmentUsages := ctx.VariableUsages(fragment) + usages = append(usages, fragmentUsages...) + } + ctx.recursiveVariableUsages[operation] = usages + return usages +} func (ctx *ValidationContext) Type() Output { return ctx.typeInfo.Type() } diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 00000000..f7f19e57 --- /dev/null +++ b/validator_test.go @@ -0,0 +1,98 @@ +package graphql_test + +import ( + "testing" + + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/gqlerrors" + "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/location" + "github.com/graphql-go/graphql/language/parser" + "github.com/graphql-go/graphql/language/source" + "github.com/graphql-go/graphql/testutil" + "reflect" +) + +func expectValid(t *testing.T, schema *graphql.Schema, queryString string) { + source := source.NewSource(&source.Source{ + Body: queryString, + Name: "GraphQL request", + }) + AST, err := parser.Parse(parser.ParseParams{Source: source}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + validationResult := graphql.ValidateDocument(schema, AST, nil) + + if !validationResult.IsValid || len(validationResult.Errors) > 0 { + t.Fatalf("Unexpected error: %v", validationResult.Errors) + } + +} + +func TestValidator_SupportsFullValidation_ValidatesQueries(t *testing.T) { + + expectValid(t, testutil.TestSchema, ` + query { + catOrDog { + ... on Cat { + furColor + } + ... on Dog { + isHousetrained + } + } + } + `) +} + +// NOTE: experimental +func TestValidator_SupportsFullValidation_ValidatesUsingACustomTypeInfo(t *testing.T) { + + // This TypeInfo will never return a valid field. + typeInfo := graphql.NewTypeInfo(&graphql.TypeInfoConfig{ + Schema: testutil.TestSchema, + FieldDefFn: func(schema *graphql.Schema, parentType graphql.Type, fieldAST *ast.Field) *graphql.FieldDefinition { + return nil + }, + }) + + ast := testutil.TestParse(t, ` + query { + catOrDog { + ... on Cat { + furColor + } + ... on Dog { + isHousetrained + } + } + } + `) + + errors := graphql.VisitUsingRules(testutil.TestSchema, typeInfo, ast, graphql.SpecifiedRules) + + expectedErrors := []gqlerrors.FormattedError{ + { + Message: "Cannot query field \"catOrDog\" on type \"QueryRoot\".", + Locations: []location.SourceLocation{ + {Line: 3, Column: 9}, + }, + }, + { + Message: "Cannot query field \"furColor\" on type \"Cat\".", + Locations: []location.SourceLocation{ + {Line: 5, Column: 13}, + }, + }, + { + Message: "Cannot query field \"isHousetrained\" on type \"Dog\".", + Locations: []location.SourceLocation{ + {Line: 8, Column: 13}, + }, + }, + } + if !reflect.DeepEqual(expectedErrors, errors) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedErrors, errors)) + } +} diff --git a/values.go b/values.go index 6b3ff169..8ade746c 100644 --- a/values.go +++ b/values.go @@ -5,11 +5,13 @@ import ( "fmt" "math" "reflect" + "strings" "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/kinds" "github.com/graphql-go/graphql/language/printer" + "sort" ) // Prepares an object map of variableValues of the correct type based on the @@ -77,10 +79,12 @@ func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, inpu "", nil, []int{}, + nil, ) } - if isValidInputValue(input, ttype) { + isValid, messages := isValidInputValue(input, ttype) + if isValid { if isNullish(input) { defaultValue := definitionAST.DefaultValue if defaultValue != nil { @@ -99,20 +103,28 @@ func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, inpu "", nil, []int{}, + nil, ) } + // convert input interface into string for error message inputStr := "" b, err := json.Marshal(input) if err == nil { inputStr = string(b) } + messagesStr := "" + if len(messages) > 0 { + messagesStr = "\n" + strings.Join(messages, "\n") + } + return "", gqlerrors.NewError( - fmt.Sprintf(`Variable "$%v" expected value of type `+ - `"%v" but got: %v.`, variable.Name.Value, printer.Print(definitionAST.Type), inputStr), + fmt.Sprintf(`Variable "$%v" got invalid value `+ + `%v.%v`, variable.Name.Value, inputStr, messagesStr), []ast.Node{definitionAST}, "", nil, []int{}, + nil, ) } @@ -208,16 +220,19 @@ func typeFromAST(schema Schema, inputTypeAST ast.Type) (Type, error) { // Given a value and a GraphQL type, determine if the value will be // accepted for that type. This is primarily useful for validating the // runtime values of query variables. -func isValidInputValue(value interface{}, ttype Input) bool { +func isValidInputValue(value interface{}, ttype Input) (bool, []string) { if ttype, ok := ttype.(*NonNull); ok { if isNullish(value) { - return false + if ttype.OfType.Name() != "" { + return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())} + } + return false, []string{"Expected non-null value, found null."} } return isValidInputValue(value, ttype.OfType) } if isNullish(value) { - return true + return true, nil } switch ttype := ttype.(type) { @@ -228,48 +243,76 @@ func isValidInputValue(value interface{}, ttype Input) bool { valType = valType.Elem() } if valType.Kind() == reflect.Slice { + messagesReduce := []string{} for i := 0; i < valType.Len(); i++ { val := valType.Index(i).Interface() - if !isValidInputValue(val, itemType) { - return false + _, messages := isValidInputValue(val, itemType) + for idx, message := range messages { + messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, idx+1, message)) } } - return true + return (len(messagesReduce) == 0), messagesReduce } return isValidInputValue(value, itemType) case *InputObject: + messagesReduce := []string{} + valueMap, ok := value.(map[string]interface{}) if !ok { - return false + return false, []string{fmt.Sprintf(`Expected "%v", found not an object.`, ttype.Name())} } fields := ttype.Fields() + // to ensure stable order of field evaluation + fieldNames := []string{} + valueMapFieldNames := []string{} + + for fieldName := range fields { + fieldNames = append(fieldNames, fieldName) + } + sort.Strings(fieldNames) + + for fieldName := range valueMap { + valueMapFieldNames = append(valueMapFieldNames, fieldName) + } + sort.Strings(valueMapFieldNames) + // Ensure every provided field is defined. - for fieldName, _ := range valueMap { + for _, fieldName := range valueMapFieldNames { if _, ok := fields[fieldName]; !ok { - return false + messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": Unknown field.`, fieldName)) } } + // Ensure every defined field is valid. - for fieldName, _ := range fields { - isValid := isValidInputValue(valueMap[fieldName], fields[fieldName].Type) - if !isValid { - return false + for _, fieldName := range fieldNames { + _, messages := isValidInputValue(valueMap[fieldName], fields[fieldName].Type) + if messages != nil { + for _, message := range messages { + messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": %v`, fieldName, message)) + } } } - return true + return (len(messagesReduce) == 0), messagesReduce } switch ttype := ttype.(type) { case *Scalar: parsedVal := ttype.ParseValue(value) - return !isNullish(parsedVal) + if isNullish(parsedVal) { + return false, []string{fmt.Sprintf(`Expected type "%v", found "%v".`, ttype.Name(), value)} + } + return true, nil + case *Enum: parsedVal := ttype.ParseValue(value) - return !isNullish(parsedVal) + if isNullish(parsedVal) { + return false, []string{fmt.Sprintf(`Expected type "%v", found "%v".`, ttype.Name(), value)} + } + return true, nil } - return false + return true, nil } // Returns true if a value is null, undefined, or NaN. diff --git a/variables_test.go b/variables_test.go index b8e118f2..adfe823c 100644 --- a/variables_test.go +++ b/variables_test.go @@ -53,6 +53,18 @@ var testInputObject *graphql.InputObject = graphql.NewInputObject(graphql.InputO }, }) +var testNestedInputObject *graphql.InputObject = graphql.NewInputObject(graphql.InputObjectConfig{ + Name: "TestNestedInputObject", + Fields: graphql.InputObjectConfigFieldMap{ + "na": &graphql.InputObjectFieldConfig{ + Type: graphql.NewNonNull(testInputObject), + }, + "nb": &graphql.InputObjectFieldConfig{ + Type: graphql.NewNonNull(graphql.String), + }, + }, +}) + func inputResolved(p graphql.ResolveParams) (interface{}, error) { input, ok := p.Args["input"] if !ok { @@ -105,6 +117,16 @@ var testType *graphql.Object = graphql.NewObject(graphql.ObjectConfig{ }, Resolve: inputResolved, }, + "fieldWithNestedInputObject": &graphql.Field{ + Type: graphql.String, + Args: graphql.FieldConfigArgument{ + "input": &graphql.ArgumentConfig{ + Type: testNestedInputObject, + DefaultValue: "Hello World", + }, + }, + Resolve: inputResolved, + }, "list": &graphql.Field{ Type: graphql.String, Args: graphql.FieldConfigArgument{ @@ -369,8 +391,8 @@ func TestVariables_ObjectsAndNullability_UsingVariables_ErrorsOnNullForNestedNon Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$input" expected value of type "TestInputObject" but ` + - `got: {"a":"foo","b":"bar","c":null}.`, + Message: `Variable "$input" got invalid value {"a":"foo","b":"bar","c":null}.` + + "\nIn field \"c\": Expected \"String!\", found null.", Locations: []location.SourceLocation{ location.SourceLocation{ Line: 2, Column: 17, @@ -404,8 +426,7 @@ func TestVariables_ObjectsAndNullability_UsingVariables_ErrorsOnIncorrectType(t Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$input" expected value of type "TestInputObject" but ` + - `got: "foo bar".`, + Message: "Variable \"$input\" got invalid value \"foo bar\".\nExpected \"TestInputObject\", found not an object.", Locations: []location.SourceLocation{ location.SourceLocation{ Line: 2, Column: 17, @@ -442,8 +463,8 @@ func TestVariables_ObjectsAndNullability_UsingVariables_ErrorsOnOmissionOfNested Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$input" expected value of type "TestInputObject" but ` + - `got: {"a":"foo","b":"bar"}.`, + Message: `Variable "$input" got invalid value {"a":"foo","b":"bar"}.` + + "\nIn field \"c\": Expected \"String!\", found null.", Locations: []location.SourceLocation{ location.SourceLocation{ Line: 2, Column: 17, @@ -469,21 +490,66 @@ func TestVariables_ObjectsAndNullability_UsingVariables_ErrorsOnOmissionOfNested t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) } } +func TestVariables_ObjectsAndNullability_UsingVariables_ErrorsOnDeepNestedErrorsAndWithManyErrors(t *testing.T) { + params := map[string]interface{}{ + "input": map[string]interface{}{ + "na": map[string]interface{}{ + "a": "foo", + }, + }, + } + expected := &graphql.Result{ + Data: nil, + Errors: []gqlerrors.FormattedError{ + gqlerrors.FormattedError{ + Message: `Variable "$input" got invalid value {"na":{"a":"foo"}}.` + + "\nIn field \"na\": In field \"c\": Expected \"String!\", found null." + + "\nIn field \"nb\": Expected \"String!\", found null.", + Locations: []location.SourceLocation{ + location.SourceLocation{ + Line: 2, Column: 19, + }, + }, + }, + }, + } + doc := ` + query q($input: TestNestedInputObject) { + fieldWithNestedObjectInput(input: $input) + } + ` + + nestedAST := testutil.TestParse(t, doc) + + // execute + ep := graphql.ExecuteParams{ + Schema: variablesTestSchema, + AST: nestedAST, + Args: params, + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) != len(expected.Errors) { + t.Fatalf("Unexpected errors, Diff: %v", testutil.Diff(expected.Errors, result.Errors)) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} func TestVariables_ObjectsAndNullability_UsingVariables_ErrorsOnAdditionOfUnknownInputField(t *testing.T) { params := map[string]interface{}{ "input": map[string]interface{}{ - "a": "foo", - "b": "bar", - "c": "baz", - "d": "dog", + "a": "foo", + "b": "bar", + "c": "baz", + "extra": "dog", }, } expected := &graphql.Result{ Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$input" expected value of type "TestInputObject" but ` + - `got: {"a":"foo","b":"bar","c":"baz","d":"dog"}.`, + Message: `Variable "$input" got invalid value {"a":"foo","b":"bar","c":"baz","extra":"dog"}.` + + "\nIn field \"extra\": Unknown field.", Locations: []location.SourceLocation{ location.SourceLocation{ Line: 2, Column: 17, @@ -1117,8 +1183,9 @@ func TestVariables_ListsAndNullability_DoesNotAllowListOfNonNullsToContainNull(t Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$input" expected value of type "[String!]" but got: ` + - `["A",null,"B"].`, + Message: `Variable "$input" got invalid value ` + + `["A",null,"B"].` + + "\nIn element #1: Expected \"String!\", found null.", Locations: []location.SourceLocation{ location.SourceLocation{ Line: 2, Column: 17, @@ -1224,8 +1291,9 @@ func TestVariables_ListsAndNullability_DoesNotAllowNonNullListOfNonNullsToContai Data: nil, Errors: []gqlerrors.FormattedError{ gqlerrors.FormattedError{ - Message: `Variable "$input" expected value of type "[String!]!" but got: ` + - `["A",null,"B"].`, + Message: `Variable "$input" got invalid value ` + + `["A",null,"B"].` + + "\nIn element #1: Expected \"String!\", found null.", Locations: []location.SourceLocation{ location.SourceLocation{ Line: 2, Column: 17,