Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add schema definition deep copying and mechanism to inspect or alter schema definitions before compilation #58

Merged
merged 1 commit into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ type Config struct {
// implementations that aren't explicitly referenced elsewhere in the schema.
AdditionalTypes map[string]graphql.NamedType

// If given, these function will be executed as the schema is built. It is executed on a clone
// of the schema and can be used to make last minute modifications to types, such as injecting
// documentation.
PreprocessGraphQLSchemaDefinition func(schema *graphql.SchemaDefinition) error

initOnce sync.Once
nodeObjectTypesByName map[string]*graphql.ObjectType
nodeTypesByModel map[reflect.Type]*NodeType
Expand Down Expand Up @@ -121,12 +126,12 @@ func (cfg *Config) init() {
})
}

func (cfg *Config) graphqlSchema() (*graphql.Schema, error) {
func (cfg *Config) graphqlSchemaDefinition() (*graphql.SchemaDefinition, error) {
additionalTypes := make([]graphql.NamedType, 0, len(cfg.AdditionalTypes))
for _, t := range cfg.AdditionalTypes {
additionalTypes = append(additionalTypes, t)
}
return graphql.NewSchema(&graphql.SchemaDefinition{
ret := &graphql.SchemaDefinition{
Query: cfg.query,
Mutation: cfg.mutation,
Subscription: cfg.subscription,
Expand All @@ -135,7 +140,22 @@ func (cfg *Config) graphqlSchema() (*graphql.Schema, error) {
"include": graphql.IncludeDirective,
"skip": graphql.SkipDirective,
},
})
}
if cfg.PreprocessGraphQLSchemaDefinition != nil {
ret = ret.Clone()
if err := cfg.PreprocessGraphQLSchemaDefinition(ret); err != nil {
return nil, err
}
}
return ret, nil
}

func (cfg *Config) graphqlSchema() (*graphql.Schema, error) {
def, err := cfg.graphqlSchemaDefinition()
if err != nil {
return nil, err
}
return graphql.NewSchema(def)
}

// NodeObjectType returns the object type for a node type previously added via AddNodeType.
Expand Down Expand Up @@ -211,19 +231,19 @@ func (cfg *Config) AddMutation(name string, def *graphql.FieldDefinition) {
// When this happens, you should return a pointer to a SubscriptionSourceStream (or an error). For
// example:
//
// Resolve: func(ctx *graphql.FieldContext) (interface{}, error) {
// if ctx.IsSubscribe {
// ticker := time.NewTicker(time.Second)
// return &apifu.SubscriptionSourceStream{
// EventChannel: ticker.C,
// Stop: ticker.Stop,
// }, nil
// } else if ctx.Object != nil {
// return ctx.Object, nil
// } else {
// return nil, fmt.Errorf("Subscriptions are not supported using this protocol.")
// }
// },
// Resolve: func(ctx *graphql.FieldContext) (interface{}, error) {
// if ctx.IsSubscribe {
// ticker := time.NewTicker(time.Second)
// return &apifu.SubscriptionSourceStream{
// EventChannel: ticker.C,
// Stop: ticker.Stop,
// }, nil
// } else if ctx.Object != nil {
// return ctx.Object, nil
// } else {
// return nil, fmt.Errorf("Subscriptions are not supported using this protocol.")
// }
// },
func (cfg *Config) AddSubscription(name string, def *graphql.FieldDefinition) {
cfg.init()

Expand Down
266 changes: 266 additions & 0 deletions graphql/schema/deep_copy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
package schema

import "fmt"

func deepCopySchemaDefinition(def *SchemaDefinition) *SchemaDefinition {
newNamedTypes := make(map[string]NamedType)

// Create shallow copies for all the named types.
Inspect(def, func(node any) bool {
if node, ok := node.(NamedType); ok {
if _, ok := newNamedTypes[node.TypeName()]; ok {
return false
}
switch t := node.(type) {
case *UnionType:
copy := *t
newNamedTypes[t.Name] = &copy
case *InterfaceType:
copy := *t
newNamedTypes[t.Name] = &copy
case *InputObjectType:
copy := *t
newNamedTypes[t.Name] = &copy
case *ObjectType:
copy := *t
newNamedTypes[t.Name] = &copy
case *EnumType:
copy := *t
newNamedTypes[t.Name] = &copy
case *ScalarType:
copy := *t
newNamedTypes[t.Name] = &copy
default:
panic(fmt.Errorf("unknown named type type: %T", t))
}
}

return true
})

// Now update all of those shallow copies to point to each other.
for _, t := range newNamedTypes {
fixNamedTypePointers(t, newNamedTypes)
}

ret := &SchemaDefinition{}
if def.Query != nil {
ret.Query = newNamedTypes[def.Query.Name].(*ObjectType)
}
if def.Mutation != nil {
ret.Mutation = newNamedTypes[def.Mutation.Name].(*ObjectType)
}
if def.Subscription != nil {
ret.Subscription = newNamedTypes[def.Subscription.Name].(*ObjectType)
}

if def.Directives != nil {
ret.Directives = make(map[string]*DirectiveDefinition, len(def.Directives))
for k, v := range def.Directives {
newValue := *v
fixNamedTypePointers(&newValue, newNamedTypes)
ret.Directives[k] = &newValue
}
}

if def.AdditionalTypes != nil {
ret.AdditionalTypes = make([]NamedType, len(def.AdditionalTypes))
for i, v := range def.AdditionalTypes {
ret.AdditionalTypes[i] = newNamedTypes[v.TypeName()]
}
}

return ret
}

func fixTypePointer(t Type, namedTypes map[string]NamedType) Type {
switch t := t.(type) {
case NamedType:
if _, ok := BuiltInTypes[t.TypeName()]; ok {
return t
} else if ret, ok := namedTypes[t.TypeName()]; ok {
return ret
}
return t
case *NonNullType:
return NewNonNullType(fixTypePointer(t.Unwrap(), namedTypes))
case *ListType:
return NewListType(fixTypePointer(t.Unwrap(), namedTypes))
default:
panic(fmt.Errorf("unknown named type type: %T", t))
}
return nil
}

// Updates pointers to named types to those contained in the given map. This function does not
// recurse into descendant named types.
func fixNamedTypePointers(node any, namedTypes map[string]NamedType) {
switch n := node.(type) {
case *UnionType:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
if n.MemberTypes != nil {
newValues := make([]*ObjectType, len(n.MemberTypes))
for i, v := range n.MemberTypes {
if newValue, ok := namedTypes[v.Name].(*ObjectType); ok {
newValues[i] = newValue
} else {
newValues[i] = v
}
}
n.MemberTypes = newValues
}
case *InterfaceType:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
if n.Fields != nil {
newValues := make(map[string]*FieldDefinition, len(n.Fields))
for k, v := range n.Fields {
newField := *v
fixNamedTypePointers(&newField, namedTypes)
newValues[k] = &newField
}
n.Fields = newValues
}
case *InputObjectType:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
if n.Fields != nil {
newValues := make(map[string]*InputValueDefinition, len(n.Fields))
for k, v := range n.Fields {
newField := *v
fixNamedTypePointers(&newField, namedTypes)
newValues[k] = &newField
}
n.Fields = newValues
}
case *ObjectType:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
if n.Fields != nil {
newValues := make(map[string]*FieldDefinition, len(n.Fields))
for k, v := range n.Fields {
newField := *v
fixNamedTypePointers(&newField, namedTypes)
newValues[k] = &newField
}
n.Fields = newValues
}
if n.ImplementedInterfaces != nil {
newValues := make([]*InterfaceType, len(n.ImplementedInterfaces))
for i, v := range n.ImplementedInterfaces {
if newValue, ok := namedTypes[v.Name].(*InterfaceType); ok {
newValues[i] = newValue
} else {
newValues[i] = v
}
}
n.ImplementedInterfaces = newValues
}
case *FieldDefinition:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
n.Type = fixTypePointer(n.Type, namedTypes)
if n.Arguments != nil {
newValues := make(map[string]*InputValueDefinition, len(n.Arguments))
for k, v := range n.Arguments {
newField := *v
fixNamedTypePointers(&newField, namedTypes)
newValues[k] = &newField
}
n.Arguments = newValues
}
case *InputValueDefinition:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
n.Type = fixTypePointer(n.Type, namedTypes)
case *Directive:
if n.Definition != nil {
newDefinition := *n.Definition
fixNamedTypePointers(&newDefinition, namedTypes)
n.Definition = &newDefinition
}
case *DirectiveDefinition:
if n.Arguments != nil {
newValues := make(map[string]*InputValueDefinition, len(n.Arguments))
for k, v := range n.Arguments {
newField := *v
fixNamedTypePointers(&newField, namedTypes)
newValues[k] = &newField
}
n.Arguments = newValues
}
case *EnumType:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
if n.Values != nil {
newValues := make(map[string]*EnumValueDefinition, len(n.Values))
for k, v := range n.Values {
newValue := *v
newValues[k] = &newValue
}
n.Values = newValues
}
case *ScalarType:
if n.Directives != nil {
newValues := make([]*Directive, len(n.Directives))
for i, v := range n.Directives {
newValue := *v
fixNamedTypePointers(&newValue, namedTypes)
newValues[i] = &newValue
}
n.Directives = newValues
}
default:
panic(fmt.Errorf("unexpected node type: %T", n))
}
}
Loading