diff --git a/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs b/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs index c6854be..c086d33 100644 --- a/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs +++ b/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs @@ -109,6 +109,40 @@ public void nested_type_policy_fail() }); } + [Fact] + public void passes_with_claim_on_input_type() + { + Settings.AddPolicy("FieldPolicy", _ => + { + _.RequireClaim("admin"); + }); + + ShouldPassRule(_=> + { + _.Query = @"query { author(input: { name: ""Quinn"" }) }"; + _.Schema = TypedSchema(); + _.User = CreatePrincipal(claims: new Dictionary + { + {"Admin", "true"} + }); + }); + } + + [Fact] + public void fails_on_missing_claim_on_input_type() + { + Settings.AddPolicy("FieldPolicy", _ => + { + _.RequireClaim("admin"); + }); + + ShouldFailRule(_=> + { + _.Query = @"query { author(input: { name: ""Quinn"" }) }"; + _.Schema = TypedSchema(); + }); + } + private ISchema BasicSchema() { var defs = @" @@ -172,5 +206,24 @@ public class Author { public string Name { get; set;} } + + private ISchema TypedSchema() + { + var query = new ObjectGraphType(); + query.Field( + "author", + arguments: new QueryArguments(new QueryArgument { Name = "input" }), + resolve: context => "testing" + ); + return new Schema { Query = query }; + } + + public class AuthorInputType : InputObjectGraphType + { + public AuthorInputType() + { + Field(x => x.Name).AuthorizeWith("FieldPolicy"); + } + } } } diff --git a/src/GraphQL.Authorization/AuthorizationValidationRule.cs b/src/GraphQL.Authorization/AuthorizationValidationRule.cs index 4f0a27b..098ce4c 100644 --- a/src/GraphQL.Authorization/AuthorizationValidationRule.cs +++ b/src/GraphQL.Authorization/AuthorizationValidationRule.cs @@ -41,7 +41,7 @@ public INodeVisitor Validate(ValidationContext context) if (argumentType == null) return; - var fieldType = argumentType.Fields.First(p => p.Name == objectFieldAst.Name); + var fieldType = argumentType.GetField(objectFieldAst.Name); CheckAuth(objectFieldAst, fieldType, userContext, context, operationType); });