Skip to content

Commit

Permalink
cost validation fixes and tests (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccbrown committed Apr 11, 2023
1 parent fe8d220 commit 0ebcee0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 35 deletions.
67 changes: 32 additions & 35 deletions graphql/validator/validate_cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,51 +83,48 @@ func ValidateCost(operationName string, variableValues map[string]interface{}, m
if node == nil {
multipliers = multipliers[:len(multipliers)-1]
ctxs = ctxs[:len(ctxs)-1]
return true
}

multiplier := multipliers[len(multipliers)-1]
ctx := ctxs[len(ctxs)-1]
newMultiplier := multiplier
newCtx := ctx

if selectionSet, ok := node.(*ast.SelectionSet); ok {
for _, selection := range selectionSet.Selections {
switch selection := selection.(type) {
case *ast.Field:
if def, ok := typeInfo.FieldDefinitions[selection]; ok && coercedVariableValues != nil {
if args, err := CoerceArgumentValues(selection, def.Arguments, selection.Arguments, coercedVariableValues); err != nil {
ret = append(ret, newSecondaryError(selection, err.Error()))
} else {
costContext := schema.FieldCostContext{
Context: ctx,
Arguments: args,
}
fieldCost := defaultCost
if def.Cost != nil {
fieldCost = def.Cost(&costContext)
}
cost = checkedNonNegativeAdd(cost, checkedNonNegativeMultiply(multiplier, fieldCost.Resolver))
if fieldCost.Multiplier > 1 {
newMultiplier = checkedNonNegativeMultiply(multiplier, fieldCost.Multiplier)
}
if fieldCost.Context != nil {
newCtx = fieldCost.Context
}
}
} else if selection.Name.Name != "__typename" {
ret = append(ret, newSecondaryError(selection, "unknown field type"))
switch selection := node.(type) {
case *ast.Field:
if def, ok := typeInfo.FieldDefinitions[selection]; ok && coercedVariableValues != nil {
if args, err := CoerceArgumentValues(selection, def.Arguments, selection.Arguments, coercedVariableValues); err != nil {
ret = append(ret, newSecondaryError(selection, err.Error()))
} else {
costContext := schema.FieldCostContext{
Context: ctx,
Arguments: args,
}
case *ast.FragmentSpread:
if _, ok := fragments[selection.FragmentName.Name]; ok {
ret = append(ret, newSecondaryError(selection, "fragment cycle detected"))
} else if def, ok := fragmentsByName[selection.FragmentName.Name]; ok {
fragments[selection.FragmentName.Name] = struct{}{}
visitNode(def)
delete(fragments, selection.FragmentName.Name)
} else {
ret = append(ret, newSecondaryError(selection, "undefined fragment"))
fieldCost := defaultCost
if def.Cost != nil {
fieldCost = def.Cost(&costContext)
}
cost = checkedNonNegativeAdd(cost, checkedNonNegativeMultiply(multiplier, fieldCost.Resolver))
if fieldCost.Multiplier > 1 {
newMultiplier = checkedNonNegativeMultiply(multiplier, fieldCost.Multiplier)
}
if fieldCost.Context != nil {
newCtx = fieldCost.Context
}
}
} else if selection.Name.Name != "__typename" {
ret = append(ret, newSecondaryError(selection, "unknown field type"))
}
case *ast.FragmentSpread:
if _, ok := fragments[selection.FragmentName.Name]; ok {
ret = append(ret, newSecondaryError(selection, "fragment cycle detected"))
} else if def, ok := fragmentsByName[selection.FragmentName.Name]; ok {
fragments[selection.FragmentName.Name] = struct{}{}
visitNode(def)
delete(fragments, selection.FragmentName.Name)
} else {
ret = append(ret, newSecondaryError(selection, "undefined fragment"))
}
}

Expand Down
5 changes: 5 additions & 0 deletions graphql/validator/validate_cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func TestValidateCost(t *testing.T) {
"TypeName": {
Source: `{__typename t:__typename}`,
},
"Context": {
Source: `{a: objectWithCostContext(cost: 10) { costFromContext }, b: objectWithCostContext(cost: 100) { costFromContext }}`,
ExpectedCost: 110,
MaxCost: 1000,
},
"Multiplier": {
Source: `{objects(first: 10) { int }}`,
ExpectedCost: 1 + 10,
Expand Down
27 changes: 27 additions & 0 deletions graphql/validator/validator_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package validator

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -73,6 +74,10 @@ var fooBarEnumType = &schema.EnumType{
},
}

type costContextKeyType int

var costContextKey costContextKeyType

func init() {
objectType.Fields = map[string]*schema.FieldDefinition{
"freeBoolean": {
Expand Down Expand Up @@ -199,6 +204,28 @@ func init() {
}
},
},
"objectWithCostContext": {
Type: objectType,
Arguments: map[string]*schema.InputValueDefinition{
"cost": {
Type: schema.IntType,
},
},
Cost: func(ctx *schema.FieldCostContext) schema.FieldCost {
cost, _ := ctx.Arguments["cost"].(int)
return schema.FieldCost{
Context: context.WithValue(ctx.Context, costContextKey, cost),
}
},
},
"costFromContext": {
Type: schema.IntType,
Cost: func(ctx *schema.FieldCostContext) schema.FieldCost {
return schema.FieldCost{
Resolver: ctx.Context.Value(costContextKey).(int),
}
},
},
"objects": {
Type: schema.NewListType(objectType),
Arguments: map[string]*schema.InputValueDefinition{
Expand Down

0 comments on commit 0ebcee0

Please sign in to comment.