Skip to content

Commit

Permalink
Merge pull request #2442 from captainsafia/security-schemes-selector
Browse files Browse the repository at this point in the history
Add support for SecuritySchemesSelector and default implementation
  • Loading branch information
domaindrivendev committed Jul 13, 2022
2 parents 23fe15d + 5b501e3 commit b2f185d
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 14 deletions.
13 changes: 13 additions & 0 deletions src/Swashbuckle.AspNetCore.Swagger/IAsyncSwaggerProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System.Threading.Tasks;
using Microsoft.OpenApi.Models;

namespace Swashbuckle.AspNetCore.Swagger
{
public interface IAsyncSwaggerProvider
{
Task<OpenApiDocument> GetSwaggerAsync(
string documentName,
string host = null,
string basePath = null);
}
}
15 changes: 11 additions & 4 deletions src/Swashbuckle.AspNetCore.Swagger/SwaggerMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,17 @@ public async Task Invoke(HttpContext httpContext, ISwaggerProvider swaggerProvid
? httpContext.Request.PathBase.Value
: null;

var swagger = swaggerProvider.GetSwagger(
documentName: documentName,
host: null,
basePath: basePath);
var swagger = swaggerProvider switch
{
IAsyncSwaggerProvider asyncSwaggerProvider => await asyncSwaggerProvider.GetSwaggerAsync(
documentName: documentName,
host: null,
basePath: basePath),
_ => swaggerProvider.GetSwagger(
documentName: documentName,
host: null,
basePath: basePath)
};

// One last opportunity to modify the Swagger Document - this time with request context
foreach (var filter in _options.PreSerializeFilters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ internal class DocumentProvider : IDocumentProvider
{
private readonly SwaggerGeneratorOptions _generatorOptions;
private readonly SwaggerOptions _options;
private readonly ISwaggerProvider _swaggerProvider;
private readonly IAsyncSwaggerProvider _swaggerProvider;

public DocumentProvider(
IOptions<SwaggerGeneratorOptions> generatorOptions,
IOptions<SwaggerOptions> options,
ISwaggerProvider swaggerProvider)
IAsyncSwaggerProvider swaggerProvider)
{
_generatorOptions = generatorOptions.Value;
_options = options.Value;
Expand All @@ -40,10 +40,10 @@ public IEnumerable<string> GetDocumentNames()
return _generatorOptions.SwaggerDocs.Keys;
}

public Task GenerateAsync(string documentName, TextWriter writer)
public async Task GenerateAsync(string documentName, TextWriter writer)
{
// Let UnknownSwaggerDocument or other exception bubble up to caller.
var swagger = _swaggerProvider.GetSwagger(documentName, host: null, basePath: null);
var swagger = await _swaggerProvider.GetSwaggerAsync(documentName, host: null, basePath: null);
var jsonWriter = new OpenApiJsonWriter(writer);
if (_options.SerializeAsV2)
{
Expand All @@ -53,8 +53,6 @@ public Task GenerateAsync(string documentName, TextWriter writer)
{
swagger.SerializeAsV3(jsonWriter);
}

return Task.CompletedTask;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public static class SwaggerGenServiceCollectionExtensions

// Register generator and it's dependencies
services.TryAddTransient<ISwaggerProvider, SwaggerGenerator>();
services.TryAddTransient<IAsyncSwaggerProvider, SwaggerGenerator>();
services.TryAddTransient(s => s.GetRequiredService<IOptions<SwaggerGeneratorOptions>>().Value);
services.TryAddTransient<ISchemaGenerator, SchemaGenerator>();
services.TryAddTransient(s => s.GetRequiredService<IOptions<SchemaGeneratorOptions>>().Value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.AspNetCore.Mvc.ModelBinding;
Expand All @@ -11,11 +13,12 @@

namespace Swashbuckle.AspNetCore.SwaggerGen
{
public class SwaggerGenerator : ISwaggerProvider
public class SwaggerGenerator : ISwaggerProvider, IAsyncSwaggerProvider
{
private readonly IApiDescriptionGroupCollectionProvider _apiDescriptionsProvider;
private readonly ISchemaGenerator _schemaGenerator;
private readonly SwaggerGeneratorOptions _options;
private readonly IAuthenticationSchemeProvider _authenticationSchemeProvider;

public SwaggerGenerator(
SwaggerGeneratorOptions options,
Expand All @@ -27,7 +30,30 @@ public class SwaggerGenerator : ISwaggerProvider
_schemaGenerator = schemaGenerator;
}

public SwaggerGenerator(
SwaggerGeneratorOptions options,
IApiDescriptionGroupCollectionProvider apiDescriptionsProvider,
ISchemaGenerator schemaGenerator,
IAuthenticationSchemeProvider authentiationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
{
_authenticationSchemeProvider = authentiationSchemeProvider;
}

public async Task<OpenApiDocument> GetSwaggerAsync(string documentName, string host = null, string basePath = null)
{
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
swaggerDoc.Components.SecuritySchemes = await GetSecuritySchemes();
return swaggerDoc;
}

public OpenApiDocument GetSwagger(string documentName, string host = null, string basePath = null)
{
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
swaggerDoc.Components.SecuritySchemes = GetSecuritySchemes().Result;
return swaggerDoc;
}

private (IEnumerable<ApiDescription>, OpenApiDocument, SchemaRepository) GetSwaggerDocument(string documentName, string host = null, string basePath = null)
{
if (!_options.SwaggerDocs.TryGetValue(documentName, out OpenApiInfo info))
throw new UnknownSwaggerDocument(documentName, _options.SwaggerDocs.Select(d => d.Key));
Expand All @@ -47,7 +73,6 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin
Components = new OpenApiComponents
{
Schemas = schemaRepository.Schemas,
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes)
},
SecurityRequirements = new List<OpenApiSecurityRequirement>(_options.SecurityRequirements)
};
Expand All @@ -60,7 +85,30 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return swaggerDoc;
return (applicableApiDescriptions, swaggerDoc, schemaRepository);
}

private async Task<Dictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
{
var securitySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
var authenticationSchemes = Enumerable.Empty<AuthenticationScheme>();
if (_authenticationSchemeProvider is not null)
{
authenticationSchemes = await _authenticationSchemeProvider.GetAllSchemesAsync();
}
var securitySchemesFromSelector = _options.SecuritySchemesSelector(authenticationSchemes);
// Favor security schemes set via options over those generated
// from the selector. For the default selector, this effectively
// ends up favoring `Bearer` authentication types explicitly set
// by the user over those derived by the selector.
foreach (var securityScheme in securitySchemesFromSelector)
{
if (!securitySchemes.ContainsKey(securityScheme.Key))
{
securitySchemes.Add(securityScheme.Key, securityScheme.Value);
}
}
return securitySchemes;
}

private IList<OpenApiServer> GenerateServers(string host, string basePath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.OpenApi.Models;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Authentication;

namespace Swashbuckle.AspNetCore.SwaggerGen
{
Expand All @@ -19,6 +20,7 @@ public SwaggerGeneratorOptions()
OperationIdSelector = DefaultOperationIdSelector;
TagsSelector = DefaultTagsSelector;
SortKeySelector = DefaultSortKeySelector;
SecuritySchemesSelector = DefaultSecuritySchemeSelector;
SchemaComparer = StringComparer.Ordinal;
Servers = new List<OpenApiServer>();
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>();
Expand Down Expand Up @@ -61,6 +63,8 @@ public SwaggerGeneratorOptions()

public IList<IDocumentFilter> DocumentFilters { get; set; }

public Func<IEnumerable<AuthenticationScheme>, Dictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

private bool DefaultDocInclusionPredicate(string documentName, ApiDescription apiDescription)
{
return apiDescription.GroupName == null || apiDescription.GroupName == documentName;
Expand Down Expand Up @@ -102,5 +106,26 @@ private string DefaultSortKeySelector(ApiDescription apiDescription)
{
return TagsSelector(apiDescription).First();
}

private Dictionary<string, OpenApiSecurityScheme> DefaultSecuritySchemeSelector(IEnumerable<AuthenticationScheme> schemes)
{
Dictionary<string, OpenApiSecurityScheme> securitySchemes = new();
#if (NET6_0_OR_GREATER)
foreach (var scheme in schemes)
{
if (scheme.Name == "Bearer")
{
securitySchemes[scheme.Name] = new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
};
}
}
#endif
return securitySchemes;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
using Swashbuckle.AspNetCore.Swagger;
using Swashbuckle.AspNetCore.TestSupport;
using Xunit;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Server.HttpSys;

namespace Swashbuckle.AspNetCore.SwaggerGen.Test
{
Expand Down Expand Up @@ -911,9 +914,78 @@ public void GetSwagger_SupportsOption_SecuritySchemes()

var document = subject.GetSwagger("v1");

Assert.Equal(new[] { "basic", "Bearer" }, document.Components.SecuritySchemes.Keys);
}

[Fact]
public async Task GetSwagger_SupportsSecuritySchemesSelector()
{
var subject = Subject(
apiDescriptions: new ApiDescription[] { },
options: new SwaggerGeneratorOptions
{
SwaggerDocs = new Dictionary<string, OpenApiInfo>
{
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
},
SecuritySchemesSelector = (schemes) => new Dictionary<string, OpenApiSecurityScheme>
{
["basic"] = new OpenApiSecurityScheme { Type = SecuritySchemeType.Http, Scheme = "basic" }
}
}
);

var document = await subject.GetSwaggerAsync("v1");

// Overrides the default set of [basic, bearer] with just [basic]
Assert.Equal(new[] { "basic" }, document.Components.SecuritySchemes.Keys);
}

[Fact]
public async Task GetSwagger_DefaultSecuritySchemeSelectorAddsBearerByDefault()
{
var subject = Subject(
apiDescriptions: new ApiDescription[] { },
options: new SwaggerGeneratorOptions
{
SwaggerDocs = new Dictionary<string, OpenApiInfo>
{
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
},
}
);

var document = await subject.GetSwaggerAsync("v1");

Assert.Equal(new[] { "Bearer" }, document.Components.SecuritySchemes.Keys);
}

[Fact]
public async Task GetSwagger_DefaultSecuritySchemesSelectorDoesNotOverrideBearer()
{
var subject = Subject(
apiDescriptions: new ApiDescription[] { },
options: new SwaggerGeneratorOptions
{
SwaggerDocs = new Dictionary<string, OpenApiInfo>
{
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
},
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>
{
["Bearer"] = new OpenApiSecurityScheme { Type = SecuritySchemeType.ApiKey, Scheme = "someSpecialOne" }
}
}
);

var document = await subject.GetSwaggerAsync("v1");

var securityScheme = Assert.Single(document.Components.SecuritySchemes);
Assert.Equal("Bearer", securityScheme.Key);
Assert.Equal(SecuritySchemeType.ApiKey, securityScheme.Value.Type);
Assert.Equal("someSpecialOne", securityScheme.Value.Scheme);
}

[Fact]
public void GetSwagger_SupportsOption_ParameterFilters()
{
Expand Down Expand Up @@ -1049,7 +1121,8 @@ private SwaggerGenerator Subject(IEnumerable<ApiDescription> apiDescriptions, Sw
return new SwaggerGenerator(
options ?? DefaultOptions,
new FakeApiDescriptionGroupCollectionProvider(apiDescriptions),
new SchemaGenerator(new SchemaGeneratorOptions(), new JsonSerializerDataContractResolver(new JsonSerializerOptions()))
new SchemaGenerator(new SchemaGeneratorOptions(), new JsonSerializerDataContractResolver(new JsonSerializerOptions())),
new TestAuthenticationSchemeProvider()
);
}

Expand All @@ -1061,4 +1134,41 @@ private SwaggerGenerator Subject(IEnumerable<ApiDescription> apiDescriptions, Sw
}
};
}

class TestAuthenticationSchemeProvider : IAuthenticationSchemeProvider
{
private readonly IEnumerable<AuthenticationScheme> _authenticationSchemes = new AuthenticationScheme[]
{
new AuthenticationScheme("Bearer", null, typeof(IAuthenticationHandler))
};

public void AddScheme(AuthenticationScheme scheme)
=> throw new NotImplementedException();
public Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
=> Task.FromResult(_authenticationSchemes);

public Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
=> throw new NotImplementedException();

public Task<AuthenticationScheme> GetSchemeAsync(string name)
=> Task.FromResult(_authenticationSchemes.First());

public void RemoveScheme(string name)
=> throw new NotImplementedException();
}
}

0 comments on commit b2f185d

Please sign in to comment.