Skip to content
Merged
33 changes: 31 additions & 2 deletions src/Security/Authorization/Core/src/AuthorizationPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,24 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
/// A new <see cref="AuthorizationPolicy"/> which represents the combination of the
/// authorization policies provided by the specified <paramref name="policyProvider"/>.
/// </returns>
public static async Task<AuthorizationPolicy?> CombineAsync(IAuthorizationPolicyProvider policyProvider, IEnumerable<IAuthorizeData> authorizeData)
public static Task<AuthorizationPolicy?> CombineAsync(IAuthorizationPolicyProvider policyProvider,
IEnumerable<IAuthorizeData> authorizeData) => CombineAsync(policyProvider, authorizeData,
Enumerable.Empty<AuthorizationPolicy>());

/// <summary>
/// Combines the <see cref="AuthorizationPolicy"/> provided by the specified
/// <paramref name="policyProvider"/>.
/// </summary>
/// <param name="policyProvider">A <see cref="IAuthorizationPolicyProvider"/> which provides the policies to combine.</param>
/// <param name="authorizeData">A collection of authorization data used to apply authorization to a resource.</param>
/// <param name="policies">A collection of <see cref="AuthorizationPolicy"/> policies to combine.</param>
/// <returns>
/// A new <see cref="AuthorizationPolicy"/> which represents the combination of the
/// authorization policies provided by the specified <paramref name="policyProvider"/>.
/// </returns>
public static async Task<AuthorizationPolicy?> CombineAsync(IAuthorizationPolicyProvider policyProvider,
IEnumerable<IAuthorizeData> authorizeData,
IEnumerable<AuthorizationPolicy> policies)
{
if (policyProvider == null)
{
Expand All @@ -120,6 +137,8 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
throw new ArgumentNullException(nameof(authorizeData));
}

var anyPolicies = policies.Any();

// Avoid allocating enumerator if the data is known to be empty
var skipEnumeratingData = false;
if (authorizeData is IList<IAuthorizeData> dataList)
Expand All @@ -137,7 +156,7 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
policyBuilder = new AuthorizationPolicyBuilder();
}

var useDefaultPolicy = true;
var useDefaultPolicy = !(anyPolicies);
if (!string.IsNullOrWhiteSpace(authorizeDatum.Policy))
{
var policy = await policyProvider.GetPolicyAsync(authorizeDatum.Policy).ConfigureAwait(false);
Expand Down Expand Up @@ -176,6 +195,16 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
}
}

if (anyPolicies)
{
policyBuilder ??= new();

foreach (var policy in policies)
{
policyBuilder.Combine(policy);
}
}

// If we have no policy by now, use the fallback policy if we have one
if (policyBuilder == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*REMOVED*~Microsoft.AspNetCore.Authorization.DefaultAuthorizationService.DefaultAuthorizationService(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerProvider! handlers, Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.Authorization.DefaultAuthorizationService!>! logger, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerContextFactory! contextFactory, Microsoft.AspNetCore.Authorization.IAuthorizationEvaluator! evaluator, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Authorization.AuthorizationOptions!>! options) -> void
Microsoft.AspNetCore.Authorization.DefaultAuthorizationPolicyProvider.DefaultAuthorizationPolicyProvider(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Authorization.AuthorizationOptions!>! options) -> void
Microsoft.AspNetCore.Authorization.DefaultAuthorizationService.DefaultAuthorizationService(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerProvider! handlers, Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.Authorization.DefaultAuthorizationService!>! logger, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerContextFactory! contextFactory, Microsoft.AspNetCore.Authorization.IAuthorizationEvaluator! evaluator, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Authorization.AuthorizationOptions!>! options) -> void
static Microsoft.AspNetCore.Authorization.AuthorizationPolicy.CombineAsync(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.Authorization.IAuthorizeData!>! authorizeData, System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.Authorization.AuthorizationPolicy!>! policies) -> System.Threading.Tasks.Task<Microsoft.AspNetCore.Authorization.AuthorizationPolicy?>!
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,55 @@ public static TBuilder RequireAuthorization<TBuilder>(this TBuilder builder, par
return builder;
}

/// <summary>
/// Adds an authorization policy to the endpoint(s).
/// </summary>
/// <param name="builder">The endpoint convention builder.</param>
/// <param name="policy">The <see cref="AuthorizationPolicy"/> policy.</param>
/// <returns>The original convention builder parameter.</returns>
public static TBuilder RequireAuthorization<TBuilder>(this TBuilder builder, AuthorizationPolicy policy)
where TBuilder : IEndpointConventionBuilder
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}

if (policy == null)
{
throw new ArgumentNullException(nameof(policy));
}

RequirePolicyCore(builder, policy);
return builder;
}

/// <summary>
/// Adds an new authorization policy configured by a callback to the endpoint(s).
/// </summary>
/// <typeparam name="TBuilder"></typeparam>
/// <param name="builder">The endpoint convention builder.</param>
/// <param name="configurePolicy">The callback used to configure the policy.</param>
/// <returns>The original convention builder parameter.</returns>
public static TBuilder RequireAuthorization<TBuilder>(this TBuilder builder, Action<AuthorizationPolicyBuilder> configurePolicy)
where TBuilder : IEndpointConventionBuilder
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}

if (configurePolicy == null)
{
throw new ArgumentNullException(nameof(configurePolicy));
}

var policyBuilder = new AuthorizationPolicyBuilder();
configurePolicy(policyBuilder);
RequirePolicyCore(builder, policyBuilder.Build());
return builder;
}

/// <summary>
/// Allows anonymous access to the endpoint by adding <see cref="AllowAnonymousAttribute" /> to the endpoint metadata. This will bypass
/// all authorization checks for the endpoint including the default authorization policy and fallback authorization policy.
Expand All @@ -94,6 +143,20 @@ public static TBuilder AllowAnonymous<TBuilder>(this TBuilder builder) where TBu
return builder;
}

private static void RequirePolicyCore<TBuilder>(TBuilder builder, AuthorizationPolicy policy)
where TBuilder : IEndpointConventionBuilder
{
builder.Add(endpointBuilder =>
{
// Only add an authorize attribute if there isn't one
if (!endpointBuilder.Metadata.Any(meta => meta is IAuthorizeData))
{
endpointBuilder.Metadata.Add(new AuthorizeAttribute());
}
endpointBuilder.Metadata.Add(policy);
});
}

private static void RequireAuthorizationCore<TBuilder>(TBuilder builder, IEnumerable<IAuthorizeData> authorizeData)
where TBuilder : IEndpointConventionBuilder
{
Expand All @@ -105,4 +168,5 @@ private static void RequireAuthorizationCore<TBuilder>(TBuilder builder, IEnumer
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ public async Task Invoke(HttpContext context)

// IMPORTANT: Changes to authorization logic should be mirrored in MVC's AuthorizeFilter
var authorizeData = endpoint?.Metadata.GetOrderedMetadata<IAuthorizeData>() ?? Array.Empty<IAuthorizeData>();
var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData);

var policies = endpoint?.Metadata.GetOrderedMetadata<AuthorizationPolicy>() ?? Array.Empty<AuthorizationPolicy>();

var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData, policies);

if (policy == null)
{
await _next(context);
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
static Microsoft.AspNetCore.Builder.AuthorizationEndpointConventionBuilderExtensions.RequireAuthorization<TBuilder>(this TBuilder builder, Microsoft.AspNetCore.Authorization.AuthorizationPolicy! policy) -> TBuilder
static Microsoft.AspNetCore.Builder.AuthorizationEndpointConventionBuilderExtensions.RequireAuthorization<TBuilder>(this TBuilder builder, System.Action<Microsoft.AspNetCore.Authorization.AuthorizationPolicyBuilder!>! configurePolicy) -> TBuilder
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,106 @@ public void RequireAuthorization_ChainedCall()
Assert.True(chainedBuilder.TestProperty);
}

[Fact]
public void RequireAuthorization_Policy()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build();

// Act
builder.RequireAuthorization(policy);

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
convention(endpointModel);

Assert.Equal(2, endpointModel.Metadata.Count);
var authMetadata = Assert.IsAssignableFrom<IAuthorizeData>(endpointModel.Metadata[0]);
Assert.Null(authMetadata.Policy);

Assert.Equal(policy, endpointModel.Metadata[1]);
}

[Fact]
public void RequireAuthorization_PolicyCallback()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var requirement = new TestRequirement();

// Act
builder.RequireAuthorization(policyBuilder => policyBuilder.Requirements.Add(requirement));

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
convention(endpointModel);

Assert.Equal(2, endpointModel.Metadata.Count);
var authMetadata = Assert.IsAssignableFrom<IAuthorizeData>(endpointModel.Metadata[0]);
Assert.Null(authMetadata.Policy);

var policy = Assert.IsAssignableFrom<AuthorizationPolicy>(endpointModel.Metadata[1]);
Assert.Equal(1, policy.Requirements.Count);
Assert.Equal(requirement, policy.Requirements[0]);
}

[Fact]
public void RequireAuthorization_PolicyCallbackWithAuthorize()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var authorize = new AuthorizeAttribute();
var requirement = new TestRequirement();

// Act
builder.RequireAuthorization(policyBuilder => policyBuilder.Requirements.Add(requirement));

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
endpointModel.Metadata.Add(authorize);
convention(endpointModel);

// Confirm that we don't add another authorize if one already exists
Assert.Equal(2, endpointModel.Metadata.Count);
Assert.Equal(authorize, endpointModel.Metadata[0]);
var policy = Assert.IsAssignableFrom<AuthorizationPolicy>(endpointModel.Metadata[1]);
Assert.Equal(1, policy.Requirements.Count);
Assert.Equal(requirement, policy.Requirements[0]);
}

[Fact]
public void RequireAuthorization_PolicyWithAuthorize()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build();
var authorize = new AuthorizeAttribute();

// Act
builder.RequireAuthorization(policy);

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
endpointModel.Metadata.Add(authorize);
convention(endpointModel);

// Confirm that we don't add another authorize if one already exists
Assert.Equal(2, endpointModel.Metadata.Count);
Assert.Equal(authorize, endpointModel.Metadata[0]);
Assert.Equal(policy, endpointModel.Metadata[1]);
}

class TestRequirement : IAuthorizationRequirement { }

[Fact]
public void AllowAnonymous_Default()
{
Expand Down
22 changes: 22 additions & 0 deletions src/Security/Authorization/test/AuthorizationMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,28 @@ public async Task OnAuthorizationAsync_WillCallPolicyProvider()
Assert.Equal(3, next.CalledCount);
}

[Fact]
public async Task CanApplyPolicyDirectlyToEndpoint()
{
// Arrange
var calledPolicy = false;
var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ =>
{
calledPolicy = true;
return true;
}).Build();

var policyProvider = new Mock<IAuthorizationPolicyProvider>();
policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build());
var next = new TestRequestDelegate();
var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
var context = GetHttpContext(anonymous: false, endpoint: CreateEndpoint(new AuthorizeAttribute(), policy));

// Act & Assert
await middleware.Invoke(context);
Assert.True(calledPolicy);
}

[Fact]
public async Task Invoke_ValidClaimShouldNotFail()
{
Expand Down
23 changes: 23 additions & 0 deletions src/Security/Authorization/test/AuthorizationPolicyFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,29 @@ public async Task CanCombineAuthorizeAttributes()
Assert.Single(combined.Requirements.OfType<RolesAuthorizationRequirement>());
}

[Fact]
public async Task CanReplaceDefaultPolicyDirectly()
{
// Arrange
var attributes = new AuthorizeAttribute[] {
new AuthorizeAttribute(),
new AuthorizeAttribute(),
};

var policies = new[] { new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build() };

var options = new AuthorizationOptions();

var provider = new DefaultAuthorizationPolicyProvider(Options.Create(options));

// Act
var combined = await AuthorizationPolicy.CombineAsync(provider, attributes, policies);

// Assert
Assert.Equal(1, combined.Requirements.Count);
Assert.Empty(combined.Requirements.OfType<DenyAnonymousAuthorizationRequirement>());
}

[Fact]
public async Task CanReplaceDefaultPolicy()
{
Expand Down
Loading