diff --git a/src/GraphQL.Tests/Utilities/SchemaBuilderTests.cs b/src/GraphQL.Tests/Utilities/SchemaBuilderTests.cs index a521890b87..f5844a6c6b 100644 --- a/src/GraphQL.Tests/Utilities/SchemaBuilderTests.cs +++ b/src/GraphQL.Tests/Utilities/SchemaBuilderTests.cs @@ -822,6 +822,29 @@ public void build_extension_type_out_of_order() type.Fields.Count.ShouldBe(2); } + [Fact] + public void reads_directives_from_types_or_extension_types() + { + var schema = Schema.For(""" + extend type Query @directiveA { + field1: String + } + + type Query @directiveB { + field2: String + } + + directive @directiveA on OBJECT + directive @directiveB on OBJECT + """); + + schema.Initialize(); + var type = schema.AllTypes["Query"].ShouldNotBeNull(); + var directives = type.GetAppliedDirectives()?.List.ShouldNotBeNull(); + directives.Where(x => x.Name == "directiveA").ShouldHaveSingleItem(); + directives.Where(x => x.Name == "directiveB").ShouldHaveSingleItem(); + } + [Fact] public async Task builds_with_customized_clr_type() { diff --git a/src/GraphQL/Utilities/SchemaBuilderExtensions.cs b/src/GraphQL/Utilities/SchemaBuilderExtensions.cs index 2e1d3ef4bf..09702e3855 100644 --- a/src/GraphQL/Utilities/SchemaBuilderExtensions.cs +++ b/src/GraphQL/Utilities/SchemaBuilderExtensions.cs @@ -35,9 +35,18 @@ public static bool AstTypeHasFields(this IProvideMetadata type) { provider.WithMetadata(AST_METAFIELD, node); //TODO: remove? - if (node is IHasDirectivesNode ast && ast.Directives?.Count > 0) + if (node is IHasDirectivesNode ast) + provider.CopyDirectivesFrom(ast); + + return provider; + } + + public static TMetadataProvider CopyDirectivesFrom(this TMetadataProvider provider, IHasDirectivesNode node) + where TMetadataProvider : IProvideMetadata + { + if (node.Directives?.Count > 0) { - foreach (var directive in ast.Directives!) + foreach (var directive in node.Directives) { provider.ApplyDirective(directive!.Name.StringValue, d => //ISSUE:allocation { @@ -49,20 +58,23 @@ public static bool AstTypeHasFields(this IProvideMetadata type) }); } } - return provider; } public static bool HasExtensionAstTypes(this IProvideMetadata type) { - return GetExtensionAstTypes(type).Count > 0; + return type.HasMetadata(EXTENSION_AST_METAFIELD) && GetExtensionAstTypes(type).Count > 0; } - public static void AddExtensionAstType(this IProvideMetadata type, T astType) where T : ASTNode + public static void AddExtensionAstType(this IProvideMetadata type, T astType) + where T : ASTNode { var types = GetExtensionAstTypes(type); types.Add(astType); type.Metadata[EXTENSION_AST_METAFIELD] = types; + + if (astType is IHasDirectivesNode ast) + type.CopyDirectivesFrom(ast); } public static List GetExtensionAstTypes(this IProvideMetadata type)