Skip to content

Commit

Permalink
Honor Required/BindRequired attributes on all parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
domaindrivendev committed Jun 14, 2018
1 parent 44b15b8 commit c977702
Show file tree
Hide file tree
Showing 27 changed files with 265 additions and 252 deletions.
5 changes: 4 additions & 1 deletion README.md
Expand Up @@ -479,7 +479,10 @@ When selecting actions for a given Swagger document, the generator invokes a _Do
```csharp
c.DocInclusionPredicate((docName, apiDesc) =>
{
var versions = apiDesc.ControllerAttributes()
if (!apiDesc.TryGetMethodInfo(out MethodInfo methodInfo)) return false;

var versions = methodInfo.DeclaringType
.GetCustomAttributes(true)
.OfType<ApiVersionAttribute>()
.SelectMany(attr => attr.Versions);

Expand Down
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using System.Reflection;
using System.Collections.Generic;
using Swashbuckle.AspNetCore.Swagger;

namespace Swashbuckle.AspNetCore.SwaggerGen
Expand All @@ -9,16 +10,18 @@ public class SwaggerAttributesOperationFilter : IOperationFilter
{
public void Apply(Operation operation, OperationFilterContext context)
{
if (context.ControllerActionDescriptor == null) return;
if (context.MethodInfo == null) return;

ApplyOperationAttributes(operation, context);
ApplyOperationFilterAttributes(operation, context);
var actionAttributes = context.MethodInfo.GetCustomAttributes(true);
var controllerAttributes = context.MethodInfo.DeclaringType.GetTypeInfo().GetCustomAttributes(true);

ApplyOperationAttributes(operation, actionAttributes);
ApplyOperationFilterAttributes(operation, actionAttributes, controllerAttributes, context);
}

private static void ApplyOperationAttributes(Operation operation, OperationFilterContext context)
private static void ApplyOperationAttributes(Operation operation, IEnumerable<object> actionAttributes)
{
var swaggerOperationAttribute = context.ControllerActionDescriptor.MethodInfo
.GetCustomAttributes(true)
var swaggerOperationAttribute = actionAttributes
.OfType<SwaggerOperationAttribute>()
.FirstOrDefault();

Expand All @@ -34,10 +37,13 @@ private static void ApplyOperationAttributes(Operation operation, OperationFilte
operation.Schemes = swaggerOperationAttribute.Schemes;
}

public static void ApplyOperationFilterAttributes(Operation operation, OperationFilterContext context)
public static void ApplyOperationFilterAttributes(
Operation operation,
IEnumerable<object> actionAttributes,
IEnumerable<object> controllerAttributes,
OperationFilterContext context)
{
var swaggerOperationFilterAttributes = context.ControllerActionDescriptor
.GetControllerAndActionAttributes(true)
var swaggerOperationFilterAttributes = actionAttributes.Union(controllerAttributes)
.OfType<SwaggerOperationFilterAttribute>();

foreach (var swaggerOperationFilterAttribute in swaggerOperationFilterAttributes)
Expand Down
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq;
using System.Reflection;
using System.Collections.Generic;
using Swashbuckle.AspNetCore.Swagger;

namespace Swashbuckle.AspNetCore.SwaggerGen
Expand All @@ -8,10 +9,10 @@ public class SwaggerResponseAttributeFilter : IOperationFilter
{
public void Apply(Operation operation, OperationFilterContext context)
{
if (context.ControllerActionDescriptor == null) return;
if (context.MethodInfo == null) return;

var swaggerResponseAttributes = context.ControllerActionDescriptor
.GetControllerAndActionAttributes(true)
var swaggerResponseAttributes = context.MethodInfo.GetCustomAttributes(true)
.Union(context.MethodInfo.DeclaringType.GetTypeInfo().GetCustomAttributes(true))
.OfType<SwaggerResponseAttribute>();

if (!swaggerResponseAttributes.Any())
Expand Down
Expand Up @@ -10,7 +10,7 @@ namespace Swashbuckle.AspNetCore.SwaggerGen
{
public static class ApiDescriptionExtensions
{
[Obsolete("Deprecated: Use OperationFilterContext.ControllerActionDescriptor")]
[Obsolete("Deprecated: Use TryGetMethodInfo")]
public static IEnumerable<object> ControllerAttributes(this ApiDescription apiDescription)
{
var controllerActionDescriptor = apiDescription.ActionDescriptor as ControllerActionDescriptor;
Expand All @@ -19,7 +19,7 @@ public static IEnumerable<object> ControllerAttributes(this ApiDescription apiDe
: controllerActionDescriptor.ControllerTypeInfo.GetCustomAttributes(true);
}

[Obsolete("Deprecated: Use OperationFilterContext.ControllerActionDescriptor")]
[Obsolete("Deprecated: Use TryGetMethodInfo")]

This comment has been minimized.

Copy link
@mattfrear

mattfrear Jul 11, 2018

Contributor

I came here wondering why these methods are now marked as Obsolete. I use them and I now get compiler warnings when I compile against Swashbuckle.AspNetCore 3.0. Is there an issue which explains this change which you could point me to?

public static IEnumerable<object> ActionAttributes(this ApiDescription apiDescription)
{
var controllerActionDescriptor = apiDescription.ActionDescriptor as ControllerActionDescriptor;
Expand All @@ -28,6 +28,15 @@ public static IEnumerable<object> ActionAttributes(this ApiDescription apiDescri
: controllerActionDescriptor.MethodInfo.GetCustomAttributes(true);
}

public static bool TryGetMethodInfo(this ApiDescription apiDescription, out MethodInfo methodInfo)
{
var controllerActionDescriptor = apiDescription.ActionDescriptor as ControllerActionDescriptor;

methodInfo = controllerActionDescriptor?.MethodInfo;

return (methodInfo != null);
}

internal static string FriendlyId(this ApiDescription apiDescription)
{
var parts = (apiDescription.RelativePathSansQueryString() + "/" + apiDescription.HttpMethod.ToLower())
Expand Down Expand Up @@ -68,10 +77,13 @@ internal static IEnumerable<string> SupportedResponseMediaTypes(this ApiDescript

internal static bool IsObsolete(this ApiDescription apiDescription)
{
var controllerActionDescriptor = apiDescription.ActionDescriptor as ControllerActionDescriptor;
return (controllerActionDescriptor != null)
? controllerActionDescriptor.GetControllerAndActionAttributes(true).OfType<ObsoleteAttribute>().Any()
: false;
if (!apiDescription.TryGetMethodInfo(out MethodInfo methodInfo))
return false;

return methodInfo.GetCustomAttributes(true)
.Union(methodInfo.DeclaringType.GetTypeInfo().GetCustomAttributes(true))
.Any(attr => attr.GetType() == typeof(ObsoleteAttribute));
}

}
}
@@ -1,4 +1,7 @@
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using System.Linq;
using System.Reflection;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.AspNetCore.Mvc.Controllers;
using Microsoft.AspNetCore.Mvc.ModelBinding;

namespace Swashbuckle.AspNetCore.SwaggerGen
Expand All @@ -15,18 +18,35 @@ public static bool IsPartOfCancellationToken(this ApiParameterDescription parame
|| name.StartsWith("WaitHandle.");
}

public static bool IsRequired(this ApiParameterDescription parameterDescription)
internal static bool TryGetParameterInfo(
this ApiParameterDescription apiParameterDescription,
ApiDescription apiDescription,
out ParameterInfo parameterInfo)
{
if (parameterDescription.RouteInfo?.IsOptional == false)
return true;
var controllerParameterDescriptor = apiDescription.ActionDescriptor.Parameters
.OfType<ControllerParameterDescriptor>()
.FirstOrDefault(descriptor =>
{
return (apiParameterDescription.Name == descriptor.BindingInfo?.BinderModelName)
|| (apiParameterDescription.Name == descriptor.Name);
});

if (parameterDescription.ModelMetadata?.IsBindingRequired == true)
return true;
parameterInfo = controllerParameterDescriptor?.ParameterInfo;

if (parameterDescription.ModelMetadata?.IsRequired == true && parameterDescription.Type.IsAssignableToNull())
return true;
return (parameterInfo != null);
}

internal static bool TryGetPropertyInfo(
this ApiParameterDescription apiParameterDescription,
out PropertyInfo propertyInfo)
{
var modelMetadata = apiParameterDescription.ModelMetadata;

propertyInfo = (modelMetadata?.ContainerType != null)
? modelMetadata.ContainerType.GetProperty(modelMetadata.PropertyName)
: null;

return false;
return (propertyInfo != null);
}
}
}
}

This file was deleted.

@@ -1,4 +1,5 @@
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using System.Reflection;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.AspNetCore.Mvc.Controllers;
using Swashbuckle.AspNetCore.Swagger;

Expand All @@ -13,18 +14,18 @@ public class OperationFilterContext
{
public OperationFilterContext(
ApiDescription apiDescription,
ISchemaRegistry schemaRegistry)
ISchemaRegistry schemaRegistry,
MethodInfo methodInfo)
{
ApiDescription = apiDescription;
ControllerActionDescriptor = apiDescription.ActionDescriptor as ControllerActionDescriptor;
SchemaRegistry = schemaRegistry;
MethodInfo = methodInfo;
}

public ApiDescription ApiDescription { get; private set; }

public ControllerActionDescriptor ControllerActionDescriptor { get; }

public ISchemaRegistry SchemaRegistry { get; private set; }

public MethodInfo MethodInfo { get; }
}
}
@@ -1,5 +1,7 @@
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.AspNetCore.Mvc.Controllers;
using System;
using System.Reflection;
using System.Collections.Generic;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Swashbuckle.AspNetCore.Swagger;

namespace Swashbuckle.AspNetCore.SwaggerGen
Expand All @@ -13,18 +15,22 @@ public class ParameterFilterContext
{
public ParameterFilterContext(
ApiParameterDescription apiParameterDescription,
ControllerParameterDescriptor controllerParameterDescriptor,
ISchemaRegistry schemaRegistry)
ISchemaRegistry schemaRegistry,
ParameterInfo parameterInfo,
PropertyInfo propertyInfo)
{
ApiParameterDescription = apiParameterDescription;
ControllerParameterDescriptor = controllerParameterDescriptor;
SchemaRegistry = schemaRegistry;
ParameterInfo = parameterInfo;
PropertyInfo = propertyInfo;
}

public ApiParameterDescription ApiParameterDescription { get; }

public ControllerParameterDescriptor ControllerParameterDescriptor { get; }

public ISchemaRegistry SchemaRegistry { get; }

public ParameterInfo ParameterInfo { get; }

public PropertyInfo PropertyInfo { get; }
}
}
Expand Up @@ -17,7 +17,7 @@ internal static bool IsRequired(this JsonProperty jsonProperty)
if (jsonProperty.Required == Newtonsoft.Json.Required.Always)
return true;

if (jsonProperty.HasAttribute<RequiredAttribute>() && jsonProperty.PropertyType.IsAssignableToNull())
if (jsonProperty.HasAttribute<RequiredAttribute>())
return true;

return false;
Expand Down
58 changes: 31 additions & 27 deletions src/Swashbuckle.AspNetCore.SwaggerGen/Generator/SwaggerGenerator.cs
Expand Up @@ -2,12 +2,11 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using System.Reflection;
using System.ComponentModel.DataAnnotations;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.AspNetCore.Mvc.Controllers;
using Microsoft.AspNetCore.Mvc.ModelBinding.Metadata;
using Swashbuckle.AspNetCore.Swagger;
using System.Reflection;

namespace Swashbuckle.AspNetCore.SwaggerGen
{
Expand Down Expand Up @@ -141,6 +140,18 @@ public class SwaggerGenerator : ISwaggerProvider
ApiDescription apiDescription,
ISchemaRegistry schemaRegistry)
{
// Try to retrieve additional metadata that's not provided by ApiExplorer
MethodInfo methodInfo;

var customAttributes = Enumerable.Empty<object>();
if (apiDescription.TryGetMethodInfo(out methodInfo))
{
customAttributes = methodInfo.GetCustomAttributes(true)
.Union(methodInfo.DeclaringType.GetTypeInfo().GetCustomAttributes(true));
}

var isDeprecated = customAttributes.Any(attr => attr.GetType() == typeof(ObsoleteAttribute));

var operation = new Operation
{
Tags = new[] { _settings.TagSelector(apiDescription) },
Expand All @@ -149,12 +160,13 @@ public class SwaggerGenerator : ISwaggerProvider
Produces = apiDescription.SupportedResponseMediaTypes().ToList(),
Parameters = CreateParameters(apiDescription, schemaRegistry),
Responses = CreateResponses(apiDescription, schemaRegistry),
Deprecated = apiDescription.IsObsolete() ? true : (bool?)null
Deprecated = isDeprecated ? true : (bool?)null
};

var filterContext = new OperationFilterContext(
apiDescription,
schemaRegistry);
schemaRegistry,
methodInfo);

foreach (var filter in _settings.OperationFilters)
{
Expand Down Expand Up @@ -198,19 +210,28 @@ public class SwaggerGenerator : ISwaggerProvider
? schemaRegistry.GetOrRegister(apiParameterDescription.Type)
: null;

var isRequired = apiParameterDescription.IsRequired();
// Try to retrieve additional metadata that's not provided by ApiExplorer
ParameterInfo parameterInfo = null;
PropertyInfo propertyInfo = null;

var customAttributes = Enumerable.Empty<object>();
if (apiParameterDescription.TryGetParameterInfo(apiDescription, out parameterInfo))
customAttributes = parameterInfo.GetCustomAttributes(true);
else if (apiParameterDescription.TryGetPropertyInfo(out propertyInfo))
customAttributes = propertyInfo.GetCustomAttributes(true);

var controllerParameterDescriptor = GetControllerParameterDescriptorOrNull(
apiDescription, apiParameterDescription);
var isRequired = customAttributes.Any(attr =>
new[] { typeof(RequiredAttribute), typeof(BindRequiredAttribute) }.Contains(attr.GetType()));

var parameter = (location == "body")
? new BodyParameter { Name = name, Schema = schema, Required = isRequired }
: CreateNonBodyParameter(name, location, schema, isRequired, schemaRegistry);

var filterContext = new ParameterFilterContext(
apiParameterDescription,
controllerParameterDescriptor,
schemaRegistry);
schemaRegistry,
parameterInfo,
propertyInfo);

foreach (var filter in _settings.ParameterFilters)
{
Expand Down Expand Up @@ -283,23 +304,6 @@ private Response CreateResponse(ApiResponseType apiResponseType, ISchemaRegistry
};
}

private ControllerParameterDescriptor GetControllerParameterDescriptorOrNull(
ApiDescription apiDescription,
ApiParameterDescription apiParameterDescription)
{
if (apiParameterDescription.ModelMetadata?.MetadataKind == ModelMetadataKind.Property)
return null;

var parameterDescriptor = apiDescription.ActionDescriptor.Parameters
.FirstOrDefault(paramDescriptor =>
{
return (apiParameterDescription.Name == paramDescriptor.BindingInfo?.BinderModelName)
|| (apiParameterDescription.Name == paramDescriptor.Name);
});

return parameterDescriptor as ControllerParameterDescriptor;
}

private static Dictionary<BindingSource, string> ParameterLocationMap = new Dictionary<BindingSource, string>
{
{ BindingSource.Form, "formData" },
Expand Down

0 comments on commit c977702

Please sign in to comment.