Skip to content

Commit

Permalink
Add ability to set custom SecuritytSchemes to OpenApiGroup (#1440)
Browse files Browse the repository at this point in the history
Fixed #1439
  • Loading branch information
altro3 committed Feb 16, 2024
1 parent ad0c126 commit 10908d8
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import io.micronaut.context.annotation.AliasFor;
import io.swagger.v3.oas.annotations.OpenAPIDefinition;
import io.swagger.v3.oas.annotations.security.SecurityScheme;

import static java.lang.annotation.RetentionPolicy.SOURCE;

Expand Down Expand Up @@ -54,5 +55,12 @@
/**
* @return OpenAPI object describing information about group.
*/
OpenAPIDefinition info();
OpenAPIDefinition info() default @OpenAPIDefinition;

/**
* @return Security schemes for OpenAPI group.
*
* @since 6.6.0
*/
SecurityScheme[] securitySchemes() default {};
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
import io.swagger.v3.oas.annotations.media.Schema.AdditionalPropertiesValue;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.security.OAuthScope;
import io.swagger.v3.oas.annotations.security.SecurityScheme;
import io.swagger.v3.oas.annotations.servers.Server;
import io.swagger.v3.oas.annotations.servers.ServerVariable;
import io.swagger.v3.oas.models.Components;
Expand All @@ -114,7 +115,6 @@
import io.swagger.v3.oas.models.media.Schema;
import io.swagger.v3.oas.models.media.StringSchema;
import io.swagger.v3.oas.models.security.SecurityRequirement;
import io.swagger.v3.oas.models.security.SecurityScheme;
import io.swagger.v3.oas.models.tags.Tag;

import com.fasterxml.jackson.annotation.JsonAnySetter;
Expand All @@ -136,8 +136,8 @@
import static io.micronaut.openapi.visitor.ConfigUtils.isJsonViewDefaultInclusion;
import static io.micronaut.openapi.visitor.ContextUtils.warn;
import static io.micronaut.openapi.visitor.ConvertUtils.parseJsonString;
import static io.micronaut.openapi.visitor.ConvertUtils.resolveExtensions;
import static io.micronaut.openapi.visitor.ConvertUtils.setDefaultValueObject;
import static io.micronaut.openapi.visitor.ConvertUtils.toTupleSubMap;
import static io.micronaut.openapi.visitor.ElementUtils.isElementNotNullable;
import static io.micronaut.openapi.visitor.ElementUtils.isFileUpload;
import static io.micronaut.openapi.visitor.ElementUtils.isNullable;
Expand Down Expand Up @@ -279,11 +279,11 @@ Map<String, List<PathItem>> resolvePathItems(VisitorContext context, List<UriMat
openAPI.setPaths(paths);
}

Map<String, List<PathItem>> resultPathItemsMap = new HashMap<>();
var resultPathItemsMap = new HashMap<String, List<PathItem>>();

for (UriMatchTemplate matchTemplate : matchTemplates) {

StringBuilder result = new StringBuilder();
var result = new StringBuilder();

boolean optionalPathVar = false;
boolean varProcess = false;
Expand Down Expand Up @@ -347,10 +347,10 @@ Map<String, List<PathItem>> resolvePathItems(VisitorContext context, List<UriMat
resultPath = contextPath + resultPath;
}

Map<Integer, String> finalPaths = new HashMap<>();
var finalPaths = new HashMap<Integer, String>();
finalPaths.put(-1, resultPath);
if (CollectionUtils.isNotEmpty(matchTemplate.getVariables())) {
List<String> optionalVars = new ArrayList<>();
var optionalVars = new ArrayList<String>();
// need check not required path variables
for (UriMatchVariable var : matchTemplate.getVariables()) {
if (var.isQuery() || !var.isOptional() || var.isExploded()) {
Expand Down Expand Up @@ -388,13 +388,14 @@ Map<String, List<PathItem>> resolvePathItems(VisitorContext context, List<UriMat
}

private List<String> addOptionalVars(List<String> paths, String var, int level) {
List<String> additionalPaths = new ArrayList<>(paths);
var additionalPaths = new ArrayList<>(paths);
if (paths.isEmpty()) {
additionalPaths.add("/{" + var + '}');
} else {
for (String path : paths) {
additionalPaths.add(path + "/{" + var + '}');
}
return additionalPaths;
}

for (String path : paths) {
additionalPaths.add(path + "/{" + var + '}');
}
return additionalPaths;
}
Expand All @@ -409,7 +410,7 @@ private List<String> addOptionalVars(List<String> paths, String var, int level)
* @return The map
*/
protected Map<CharSequence, Object> toValueMap(Map<CharSequence, Object> values, VisitorContext context, @Nullable ClassElement jsonViewClass) {
Map<CharSequence, Object> newValues = new HashMap<>(values.size());
var newValues = new HashMap<CharSequence, Object>(values.size());
for (Map.Entry<CharSequence, Object> entry : values.entrySet()) {
CharSequence key = entry.getKey();
Object value = entry.getValue();
Expand Down Expand Up @@ -723,19 +724,6 @@ private Map<CharSequence, Object> resolveAnnotationValues(VisitorContext context
return valueMap;
}

private Map<String, String> toTupleSubMap(Object[] a, String entryKey, String entryValue) {
Map<String, String> params = new LinkedHashMap<>();
for (Object o : a) {
AnnotationValue<?> sv = (AnnotationValue<?>) o;
final Optional<String> n = sv.stringValue(entryKey);
final Optional<String> expr = sv.stringValue(entryValue);
if (n.isPresent() && expr.isPresent()) {
params.put(n.get(), expr.get());
}
}
return params;
}

private boolean isTypeNullable(ClassElement type) {
return type.isAssignable(Optional.class);
}
Expand Down Expand Up @@ -2638,84 +2626,16 @@ private Schema<?> getPrimitiveType(ClassElement type, String typeName) {
}

/**
* Processes {@link io.swagger.v3.oas.annotations.security.SecurityScheme}
* Processes {@link SecurityScheme}
* annotations.
*
* @param element The element
* @param context The visitor context
*/
protected void processSecuritySchemes(ClassElement element, VisitorContext context) {
final List<AnnotationValue<io.swagger.v3.oas.annotations.security.SecurityScheme>> values = element
.getAnnotationValuesByType(io.swagger.v3.oas.annotations.security.SecurityScheme.class);
final OpenAPI openAPI = Utils.resolveOpenApi(context);
for (AnnotationValue<io.swagger.v3.oas.annotations.security.SecurityScheme> securityRequirementAnnotationValue : values) {

final Map<CharSequence, Object> map = toValueMap(securityRequirementAnnotationValue.getValues(), context, null);

securityRequirementAnnotationValue.stringValue("name")
.ifPresent(name -> {
if (map.containsKey("paramName")) {
map.put("name", map.remove("paramName"));
}

Utils.normalizeEnumValues(map, CollectionUtils.mapOf("type", SecurityScheme.Type.class, "in", SecurityScheme.In.class));

String type = (String) map.get("type");
if (!SecurityScheme.Type.APIKEY.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "name", context, false);
removeAndWarnSecSchemeProp(map, "in", context);
}
if (!SecurityScheme.Type.OAUTH2.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "flows", context);
}
if (!SecurityScheme.Type.OPENIDCONNECT.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "openIdConnectUrl", context);
}
if (!SecurityScheme.Type.HTTP.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "scheme", context);
removeAndWarnSecSchemeProp(map, "bearerFormat", context);
}

if (SecurityScheme.Type.HTTP.toString().equals(type)) {
if (!map.containsKey("scheme")) {
warn("Can't use http security scheme without 'scheme' property", context);
} else if (!map.get("scheme").equals("bearer") && map.containsKey("bearerFormat")) {
warn("Should NOT have a `bearerFormat` property without `scheme: bearer` being set", context);
}
}

if (map.containsKey("ref") || map.containsKey("$ref")) {
Object ref = map.get("ref");
if (ref == null) {
ref = map.get("$ref");
}
map.clear();
map.put("$ref", ref);
}

try {
JsonNode node = toJson(map, context, null);
SecurityScheme securityScheme = ConvertUtils.treeToValue(node, SecurityScheme.class, context);
if (securityScheme != null) {
resolveExtensions(node).ifPresent(extensions -> BeanMap.of(securityScheme).put("extensions", extensions));
resolveComponents(openAPI).addSecuritySchemes(name, securityScheme);
}
} catch (JsonProcessingException e) {
// ignore
}
});
}
}

private void removeAndWarnSecSchemeProp(Map<CharSequence, Object> map, String prop, VisitorContext context) {
removeAndWarnSecSchemeProp(map, prop, context, true);
}

private void removeAndWarnSecSchemeProp(Map<CharSequence, Object> map, String prop, VisitorContext context, boolean withWarn) {
if (map.containsKey(prop) && withWarn) {
warn("'" + prop + "' property can't set for securityScheme with type " + map.get("type") + ". Skip it", context);
}
map.remove(prop);
var values = element.getAnnotationValuesByType(SecurityScheme.class);
final OpenAPI openApi = Utils.resolveOpenApi(context);
ConvertUtils.addSecuritySchemes(openApi, values, context);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import io.micronaut.core.beans.BeanMap;
import io.micronaut.core.util.ArrayUtils;
import io.micronaut.core.util.CollectionUtils;
import io.micronaut.core.util.StringUtils;
import io.micronaut.inject.ast.ClassElement;
import io.micronaut.inject.ast.Element;
import io.micronaut.inject.ast.ElementQuery;
Expand All @@ -56,13 +57,16 @@
import io.micronaut.inject.visitor.VisitorContext;
import io.micronaut.openapi.swagger.core.util.PrimitiveType;
import io.swagger.v3.oas.annotations.extensions.Extension;
import io.swagger.v3.oas.annotations.security.OAuthScope;
import io.swagger.v3.oas.annotations.servers.Server;
import io.swagger.v3.oas.annotations.servers.ServerVariable;
import io.swagger.v3.oas.models.OpenAPI;
import io.swagger.v3.oas.models.media.Content;
import io.swagger.v3.oas.models.media.MediaType;
import io.swagger.v3.oas.models.media.Schema;
import io.swagger.v3.oas.models.responses.ApiResponse;
import io.swagger.v3.oas.models.security.SecurityRequirement;
import io.swagger.v3.oas.models.security.SecurityScheme;

import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand All @@ -75,6 +79,7 @@
import static io.micronaut.openapi.visitor.ContextUtils.warn;
import static io.micronaut.openapi.visitor.SchemaUtils.TYPE_OBJECT;
import static io.micronaut.openapi.visitor.SchemaUtils.processExtensions;
import static io.micronaut.openapi.visitor.Utils.resolveComponents;

/**
* Convert utilities methods.
Expand Down Expand Up @@ -124,7 +129,7 @@ public static JsonNode toJson(Map<CharSequence, Object> values, VisitorContext c
}

public static Map<CharSequence, Object> toValueMap(Map<CharSequence, Object> values, VisitorContext context) {
Map<CharSequence, Object> newValues = new HashMap<>(values.size());
var newValues = new HashMap<CharSequence, Object>(values.size());
for (Map.Entry<CharSequence, Object> entry : values.entrySet()) {
CharSequence key = entry.getKey();
Object value = entry.getValue();
Expand Down Expand Up @@ -171,6 +176,9 @@ public static Map<CharSequence, Object> toValueMap(Map<CharSequence, Object> val
servers.add(variables);
}
newValues.put(key, servers);
} else if (OAuthScope.class.getName().equals(annotationName)) {
Map<String, String> params = toTupleSubMap(a, "name", "description");
newValues.put(key, params);
} else if (ServerVariable.class.getName().equals(annotationName)) {
Map<String, Map<CharSequence, Object>> variables = new LinkedHashMap<>();
for (Object o : a) {
Expand Down Expand Up @@ -344,6 +352,80 @@ public static Optional<Map<String, Object>> resolveExtensions(JsonNode jn) {
return Optional.empty();
}

public static void addSecuritySchemes(OpenAPI openApi,
List<AnnotationValue<io.swagger.v3.oas.annotations.security.SecurityScheme>> values,
VisitorContext context) {
for (var securityRequirementAnnValue : values) {

final Map<CharSequence, Object> map = toValueMap(securityRequirementAnnValue.getValues(), context);

var name = securityRequirementAnnValue.stringValue("name").orElse(null);
if (StringUtils.isEmpty(name)) {
continue;
}
if (map.containsKey("paramName")) {
map.put("name", map.remove("paramName"));
}

Utils.normalizeEnumValues(map, CollectionUtils.mapOf("type", SecurityScheme.Type.class, "in", SecurityScheme.In.class));

String type = (String) map.get("type");
if (!SecurityScheme.Type.APIKEY.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "name", context, false);
removeAndWarnSecSchemeProp(map, "in", context);
}
if (!SecurityScheme.Type.OAUTH2.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "flows", context);
}
if (!SecurityScheme.Type.OPENIDCONNECT.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "openIdConnectUrl", context);
}
if (!SecurityScheme.Type.HTTP.toString().equals(type)) {
removeAndWarnSecSchemeProp(map, "scheme", context);
removeAndWarnSecSchemeProp(map, "bearerFormat", context);
}

if (SecurityScheme.Type.HTTP.toString().equals(type)) {
if (!map.containsKey("scheme")) {
warn("Can't use http security scheme without 'scheme' property", context);
} else if (!map.get("scheme").equals("bearer") && map.containsKey("bearerFormat")) {
warn("Should NOT have a `bearerFormat` property without `scheme: bearer` being set", context);
}
}

if (map.containsKey("ref") || map.containsKey("$ref")) {
Object ref = map.get("ref");
if (ref == null) {
ref = map.get("$ref");
}
map.clear();
map.put("$ref", ref);
}

try {
JsonNode node = toJson(map, context);
SecurityScheme securityScheme = ConvertUtils.treeToValue(node, SecurityScheme.class, context);
if (securityScheme != null) {
resolveExtensions(node).ifPresent(extensions -> BeanMap.of(securityScheme).put("extensions", extensions));
resolveComponents(openApi).addSecuritySchemes(name, securityScheme);
}
} catch (JsonProcessingException e) {
// ignore
}
}
}

private static void removeAndWarnSecSchemeProp(Map<CharSequence, Object> map, String prop, VisitorContext context) {
removeAndWarnSecSchemeProp(map, prop, context, true);
}

private static void removeAndWarnSecSchemeProp(Map<CharSequence, Object> map, String prop, VisitorContext context, boolean withWarn) {
if (map.containsKey(prop) && withWarn) {
warn("'" + prop + "' property can't set for securityScheme with type " + map.get("type") + ". Skip it", context);
}
map.remove(prop);
}

/**
* Maps annotation value to {@link io.swagger.v3.oas.annotations.security.SecurityRequirement}.
* Correct format is:
Expand Down Expand Up @@ -558,4 +640,17 @@ public static Object parseByTypeAndFormat(String valueStr, String type, String f

return valueStr;
}

public static Map<String, String> toTupleSubMap(Object[] a, String entryKey, String entryValue) {
var params = new LinkedHashMap<String, String>();
for (Object o : a) {
AnnotationValue<?> sv = (AnnotationValue<?>) o;
final Optional<String> n = sv.stringValue(entryKey);
final Optional<String> expr = sv.stringValue(entryValue);
if (n.isPresent() && expr.isPresent()) {
params.put(n.get(), expr.get());
}
}
return params;
}
}
Loading

0 comments on commit 10908d8

Please sign in to comment.