diff --git a/build.sbt b/build.sbt index 451b886c20..9239d40046 100644 --- a/build.sbt +++ b/build.sbt @@ -48,6 +48,7 @@ val exampleCases: List[(java.io.File, String, Boolean, List[String])] = List( (sampleResource("plain.json"), "tests.dtos", false, List.empty), (sampleResource("polymorphism.yaml"), "polymorphism", false, List.empty), (sampleResource("raw-response.yaml"), "raw", false, List.empty), + (sampleResource("security.yaml"), "security", false, List.empty), (sampleResource("server1.yaml"), "tracer", true, List.empty), (sampleResource("server2.yaml"), "tracer", true, List.empty), (sampleResource("pathological-parameters.yaml"), "pathological", false, List.empty) diff --git a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardHelpers.scala b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardHelpers.scala new file mode 100644 index 0000000000..f640720392 --- /dev/null +++ b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardHelpers.scala @@ -0,0 +1,87 @@ +package com.twilio.guardrail.generators.Java + +import com.twilio.guardrail.{ SupportDefinition, Target } +import com.twilio.guardrail.generators.syntax.Java.loadSupportDefinitionFromString +import com.twilio.guardrail.languages.JavaLanguage + +object DropwizardHelpers { + def httpSecurityUtilsSupportDef: Target[SupportDefinition[JavaLanguage]] = loadSupportDefinitionFromString( + "HttpSecurityUtils", + """ + import java.nio.charset.StandardCharsets; + import java.util.Base64; + import java.util.Optional; + import java.util.Locale; + + public class HttpSecurityUtils { + public static class HttpBasicCredentials { + public static Optional parse(final Optional authHeader) { + return authHeader.flatMap(hdr -> { + final String[] parts = hdr.trim().split("\\s+"); + if (parts.length == 2) { + if ("basic".equals(parts[0].toLowerCase(Locale.US))) { + final String userPass = new String(Base64.getDecoder().decode(parts[1].trim()), StandardCharsets.UTF_8); + final String[] userPassParts = userPass.split(":", 2); + if (userPassParts.length == 2) { + return Optional.of(new HttpBasicCredentials(userPassParts[0], userPassParts[1])); + } else { + return Optional.of(new HttpBasicCredentials(userPassParts[0], "")); + } + } else { + return Optional.empty(); + } + } else { + return Optional.empty(); + } + }); + } + + private final String username; + private final String password; + + private HttpBasicCredentials(final String username, final String password) { + this.username = username; + this.password = password; + } + + public String getUsername() { + return this.username; + } + + public String getPassword() { + return this.password; + } + } + + public static class HttpBearerCredentials { + public static Optional parse(final Optional authHeader) { + return authHeader.flatMap(hdr -> { + final String[] parts = hdr.trim().split("\\s+"); + if (parts.length == 2) { + if ("bearer".equals(parts[0].toLowerCase(Locale.US))) { + return Optional.of(new HttpBearerCredentials(parts[1])); + } else { + return Optional.empty(); + } + } else { + return Optional.empty(); + } + }); + } + + private final String token; + + private HttpBearerCredentials(final String token) { + this.token = token; + } + + public String getToken() { + return this.token; + } + } + + private HttpSecurityUtils() {} + } + """ + ) +} diff --git a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardServerGenerator.scala b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardServerGenerator.scala index ed2053e7fb..00ebd81070 100644 --- a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardServerGenerator.scala +++ b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/DropwizardServerGenerator.scala @@ -16,14 +16,16 @@ import com.twilio.guardrail.{ ADT, ClassDefinition, EnumDefinition, RandomType, import com.twilio.guardrail.extract.ServerRawResponse import com.twilio.guardrail.generators.ScalaParameters import com.twilio.guardrail.generators.syntax.Java._ +import com.twilio.guardrail.generators.syntax.RichString import com.twilio.guardrail.languages.JavaLanguage import com.twilio.guardrail.protocol.terms.{ Response, Responses } import com.twilio.guardrail.protocol.terms.server._ import com.twilio.guardrail.shims.OperationExt -import com.twilio.guardrail.terms.RouteMeta +import com.twilio.guardrail.terms.{ ApiKeySecurityScheme, HttpSecurityScheme, OAuth2SecurityScheme, OpenIdConnectSecurityScheme, RouteMeta, SecurityScheme } import io.swagger.v3.oas.models.Operation import io.swagger.v3.oas.models.PathItem.HttpMethod import io.swagger.v3.oas.models.responses.ApiResponse +import io.swagger.v3.oas.models.security.{ SecurityScheme => SwSecurityScheme } import java.util import scala.collection.JavaConverters._ import scala.language.existentials @@ -44,6 +46,9 @@ object DropwizardServerGenerator { private val RESPONSE_BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("Response.ResponseBuilder") private val LOGGER_TYPE = JavaParser.parseClassOrInterfaceType("Logger") + private val HTTP_BASIC_CREDENTIALS_TYPE = JavaParser.parseClassOrInterfaceType("HttpSecurityUtils.HttpBasicCredentials") + private val HTTP_BEARER_CREDENTIALS_TYPE = JavaParser.parseClassOrInterfaceType("HttpSecurityUtils.HttpBearerCredentials") + private def removeEmpty(s: String): Option[String] = if (s.trim.isEmpty) None else Some(s.trim) private def splitPathComponents(s: String): List[String] = s.split("/").flatMap(removeEmpty).toList @@ -126,6 +131,112 @@ object DropwizardServerGenerator { ) } + case class SecurityParameters(routeParameters: List[Parameter], + routeStatements: List[Statement], + handlerParameters: List[Parameter], + handlerArgs: List[Expression]) + + type UnknownHttpAuthSchemeHandler = (Operation, String, SecurityScheme, String) => Target[SecurityParameters] + + def emptyUnknownHttpAuthSchemeHandler(operation: Operation, + schemeName: String, + securityScheme: SecurityScheme, + authScheme: String): Target[SecurityParameters] = + Target.raiseError(s"HTTP auth scheme $authScheme is not yet supported") + + def generateSecurityParams(operation: Operation, + securitySchemes: Map[String, SecurityScheme], + unknownHttpAuthSchemeHandler: UnknownHttpAuthSchemeHandler = emptyUnknownHttpAuthSchemeHandler): Target[SecurityParameters] = + Option(operation.getSecurity).toList + .flatMap(_.asScala) + .flatTraverse({ requirement => + requirement.asScala.toList.traverse({ + case (schemeName, _) => + securitySchemes + .get(schemeName) + .fold( + Target.raiseError[SecurityParameters](s"Operation '${operation.getOperationId} references undefined security scheme $schemeName") + )({ + case ApiKeySecurityScheme(name, in, typeName, _) => + for { + annotationName <- in match { + case SwSecurityScheme.In.QUERY => + Target.pure("QueryParam") + case SwSecurityScheme.In.HEADER => + Target.pure("HeaderParam") + case SwSecurityScheme.In.COOKIE => + Target.raiseError[String]("API Key security schemes stored in cookies are not yet supported") + } + rawParameterType <- typeName.fold(Target.pure(STRING_TYPE))(safeParseClassOrInterfaceType) + } yield { + val parameterType = if (rawParameterType.isOptional) rawParameterType else optionalType(rawParameterType) + val handlerParameter = new Parameter(util.EnumSet.of(FINAL), parameterType, new SimpleName(name.toCamelCase)) + val routeParameter = handlerParameter + .clone() + .addAnnotation(new SingleMemberAnnotationExpr(new Name(annotationName), new StringLiteralExpr(name))) + SecurityParameters(List(routeParameter), List.empty, List(handlerParameter), List(new NameExpr(name.toCamelCase))) + } + + case scheme @ HttpSecurityScheme(authScheme, _) => + authScheme match { + case "basic" => + val routeParameters = List( + new Parameter(util.EnumSet.of(FINAL), optionalType(STRING_TYPE), new SimpleName("httpBasicAuthorization")) + .addAnnotation(new SingleMemberAnnotationExpr(new Name("HeaderParam"), new StringLiteralExpr("Authorization"))) + ) + val handlerParameters = List( + new Parameter(util.EnumSet.of(FINAL), optionalType(HTTP_BASIC_CREDENTIALS_TYPE), new SimpleName("httpBasicCredentials")) + ) + val handlerArgs = List( + new MethodCallExpr( + new NameExpr("HttpSecurityUtils.HttpBasicCredentials"), + "parse", + new NodeList[Expression](new NameExpr("httpBasicAuthorization")) + ) + ) + Target.pure(SecurityParameters(routeParameters, List.empty, handlerParameters, handlerArgs)) + + case "bearer" => + val routeParameters = List( + new Parameter(util.EnumSet.of(FINAL), optionalType(STRING_TYPE), new SimpleName("httpBearerAuthorization")) + .addAnnotation(new SingleMemberAnnotationExpr(new Name("HeaderParam"), new StringLiteralExpr("Authorization"))) + ) + val handlerParameters = List( + new Parameter(util.EnumSet.of(FINAL), optionalType(HTTP_BEARER_CREDENTIALS_TYPE), new SimpleName("httpBearerCredentials")) + ) + val handlerArgs = List( + new MethodCallExpr( + new NameExpr("HttpSecurityUtils.HttpBearerCredentials"), + "parse", + new NodeList[Expression](new NameExpr("httpBearerAuthorization")) + ) + ) + Target.pure(SecurityParameters(routeParameters, List.empty, handlerParameters, handlerArgs)) + + case _ => + unknownHttpAuthSchemeHandler(operation, schemeName, scheme, authScheme) + } + + case _: OpenIdConnectSecurityScheme => + Target.raiseError("OpenID Connect is not yet supported") + + case _: OAuth2SecurityScheme => + Target.raiseError("OAuth2 is not yet supported") + }) + }) + }) + .map( + _.foldLeft(SecurityParameters(List.empty, List.empty, List.empty, List.empty))( + (accum, next) => + SecurityParameters( + routeParameters = accum.routeParameters ++ next.routeParameters, + routeStatements = accum.routeStatements ++ next.routeStatements, + handlerParameters = accum.handlerParameters ++ next.handlerParameters, + handlerArgs = accum.handlerArgs ++ next.handlerArgs + ) + ) + ) + def generateResponseSuperClass(name: String): Target[ClassOrInterfaceDeclaration] = { val cls = new ClassOrInterfaceDeclaration(util.EnumSet.of(ABSTRACT), false, name) cls.addField(PrimitiveType.intType, "statusCode", PRIVATE, FINAL) @@ -259,6 +370,7 @@ object DropwizardServerGenerator { parameters: ScalaParameters[JavaLanguage], responses: Responses[JavaLanguage], protocolElems: List[StrictProtocolElems[JavaLanguage]], + securitySchemes: Map[String, SecurityScheme], handlerName: String): Target[(MethodDeclaration, MethodDeclaration)] = { val operationId = operation.getOperationId parameters.parameters.foreach(p => p.param.setType(p.param.getType.unbox)) @@ -302,162 +414,175 @@ object DropwizardServerGenerator { parameter } - val methodParams: List[Parameter] = List( - (parameters.pathParams, "PathParam"), - (parameters.headerParams, "HeaderParam"), - (parameters.queryStringParams, "QueryParam"), - (parameters.formParams, if (consumes.contains(RouteMeta.MultipartFormData)) "FormDataParam" else "FormParam") - ).flatMap({ - case (params, annotationName) => - params.map(param => addParamAnnotation(param.param, annotationName, param.argName.value)) - }) ++ parameters.bodyParams.map(_.param) - - methodParams.foreach(method.addParameter) - method.addParameter( - new Parameter(util.EnumSet.of(FINAL), ASYNC_RESPONSE_TYPE, new SimpleName("asyncResponse")).addMarkerAnnotation("Suspended") - ) + for { + securityParameters <- generateSecurityParams(operation, securitySchemes) + } yield { + val methodParams: List[Parameter] = List( + (parameters.pathParams, "PathParam"), + (parameters.headerParams, "HeaderParam"), + (parameters.queryStringParams, "QueryParam"), + (parameters.formParams, if (consumes.contains(RouteMeta.MultipartFormData)) "FormDataParam" else "FormParam") + ).flatMap({ + case (params, annotationName) => + params.map(param => addParamAnnotation(param.param, annotationName, param.argName.value)) + }) ++ parameters.bodyParams.map(_.param) + + securityParameters.routeParameters.foreach(method.addParameter) + methodParams.foreach(method.addParameter) + method.addParameter( + new Parameter(util.EnumSet.of(FINAL), ASYNC_RESPONSE_TYPE, new SimpleName("asyncResponse")).addMarkerAnnotation("Suspended") + ) - val (responseType, resultResumeBody) = - ServerRawResponse(operation) - .filter(_ == true) - .fold({ - val responseName = s"${handlerName}.${operationId.capitalize}Response" - - val entitySetterIfTree = NonEmptyList - .fromList(responses.value.collect({ - case Response(statusCodeName, Some(_)) => statusCodeName - })) - .map(_.reverse.foldLeft[IfStmt](null)({ - case (nextIfTree, statusCodeName) => - val responseSubclassType = JavaParser.parseClassOrInterfaceType(s"${responseName}.${statusCodeName}") - new IfStmt( - new InstanceOfExpr(new NameExpr("result"), responseSubclassType), - new BlockStmt( - new NodeList( - new ExpressionStmt( - new MethodCallExpr( - new NameExpr("builder"), - "entity", - new NodeList[Expression]( - new MethodCallExpr( - new EnclosedExpr(new CastExpr(responseSubclassType, new NameExpr("result"))), - "getEntityBody" + val (responseType, resultResumeBody) = + ServerRawResponse(operation) + .filter(_ == true) + .fold({ + val responseName = s"${handlerName}.${operationId.capitalize}Response" + + val entitySetterIfTree = NonEmptyList + .fromList(responses.value.collect({ + case Response(statusCodeName, Some(_)) => statusCodeName + })) + .map(_.reverse.foldLeft[IfStmt](null)({ + case (nextIfTree, statusCodeName) => + val responseSubclassType = JavaParser.parseClassOrInterfaceType(s"${responseName}.${statusCodeName}") + new IfStmt( + new InstanceOfExpr(new NameExpr("result"), responseSubclassType), + new BlockStmt( + new NodeList( + new ExpressionStmt( + new MethodCallExpr( + new NameExpr("builder"), + "entity", + new NodeList[Expression]( + new MethodCallExpr( + new EnclosedExpr(new CastExpr(responseSubclassType, new NameExpr("result"))), + "getEntityBody" + ) ) ) ) ) - ) - ), - nextIfTree - ) - })) + ), + nextIfTree + ) + })) - ( - JavaParser.parseClassOrInterfaceType(responseName), ( - List[Statement]( - new ExpressionStmt( - new VariableDeclarationExpr( - new VariableDeclarator( - RESPONSE_BUILDER_TYPE, - "builder", - new MethodCallExpr(new NameExpr("Response"), - "status", - new NodeList[Expression](new MethodCallExpr(new NameExpr("result"), "getStatusCode"))) - ), - FINAL + JavaParser.parseClassOrInterfaceType(responseName), + ( + List[Statement]( + new ExpressionStmt( + new VariableDeclarationExpr( + new VariableDeclarator( + RESPONSE_BUILDER_TYPE, + "builder", + new MethodCallExpr(new NameExpr("Response"), + "status", + new NodeList[Expression](new MethodCallExpr(new NameExpr("result"), "getStatusCode"))) + ), + FINAL + ) + ) + ) ++ entitySetterIfTree ++ List( + new ExpressionStmt( + new MethodCallExpr(new NameExpr("asyncResponse"), "resume", new NodeList[Expression](new MethodCallExpr(new NameExpr("builder"), "build"))) ) ) - ) ++ entitySetterIfTree ++ List( + ).toNodeList + ) + })({ _ => + ( + RESPONSE_TYPE, + new NodeList( new ExpressionStmt( - new MethodCallExpr(new NameExpr("asyncResponse"), "resume", new NodeList[Expression](new MethodCallExpr(new NameExpr("builder"), "build"))) - ) - ) - ).toNodeList - ) - })({ _ => - ( - RESPONSE_TYPE, - new NodeList( - new ExpressionStmt( - new MethodCallExpr( - new NameExpr("asyncResponse"), - "resume", - new NodeList[Expression](new NameExpr("result")) + new MethodCallExpr( + new NameExpr("asyncResponse"), + "resume", + new NodeList[Expression](new NameExpr("result")) + ) ) ) ) - ) - }) + }) - val whenCompleteLambda = new LambdaExpr( - new NodeList( - new Parameter(util.EnumSet.of(FINAL), responseType, new SimpleName("result")), - new Parameter(util.EnumSet.of(FINAL), THROWABLE_TYPE, new SimpleName("err")) - ), - new BlockStmt( + val whenCompleteLambda = new LambdaExpr( new NodeList( - new IfStmt( - new BinaryExpr(new NameExpr("err"), new NullLiteralExpr, BinaryExpr.Operator.NOT_EQUALS), - new BlockStmt( - new NodeList( - new ExpressionStmt( - new MethodCallExpr( - new NameExpr("logger"), - "error", - new NodeList[Expression]( - new StringLiteralExpr(s"${handlerName}.${operationId} threw an exception ({}): {}"), - new MethodCallExpr(new MethodCallExpr(new NameExpr("err"), "getClass"), "getName"), - new MethodCallExpr(new NameExpr("err"), "getMessage"), - new NameExpr("err") + new Parameter(util.EnumSet.of(FINAL), responseType, new SimpleName("result")), + new Parameter(util.EnumSet.of(FINAL), THROWABLE_TYPE, new SimpleName("err")) + ), + new BlockStmt( + new NodeList( + new IfStmt( + new BinaryExpr(new NameExpr("err"), new NullLiteralExpr, BinaryExpr.Operator.NOT_EQUALS), + new BlockStmt( + new NodeList( + new ExpressionStmt( + new MethodCallExpr( + new NameExpr("logger"), + "error", + new NodeList[Expression]( + new StringLiteralExpr(s"${handlerName}.${operationId} threw an exception ({}): {}"), + new MethodCallExpr(new MethodCallExpr(new NameExpr("err"), "getClass"), "getName"), + new MethodCallExpr(new NameExpr("err"), "getMessage"), + new NameExpr("err") + ) ) - ) - ), - new ExpressionStmt( - new MethodCallExpr( - new NameExpr("asyncResponse"), - "resume", - new NodeList[Expression]( - new MethodCallExpr(new MethodCallExpr( - new NameExpr("Response"), - "status", - new NodeList[Expression](new IntegerLiteralExpr(500)) - ), - "build") + ), + new ExpressionStmt( + new MethodCallExpr( + new NameExpr("asyncResponse"), + "resume", + new NodeList[Expression]( + new MethodCallExpr(new MethodCallExpr( + new NameExpr("Response"), + "status", + new NodeList[Expression](new IntegerLiteralExpr(500)) + ), + "build") + ) ) ) ) - ) - ), - new BlockStmt(resultResumeBody) + ), + new BlockStmt(resultResumeBody) + ) ) - ) - ), - true - ) + ), + true + ) - val handlerCall = new MethodCallExpr( - new FieldAccessExpr(new ThisExpr, "handler"), - operationId, - new NodeList[Expression](methodParams.map(param => new NameExpr(param.getName.asString)): _*) - ) + val handlerCall = new MethodCallExpr( + new FieldAccessExpr(new ThisExpr, "handler"), + operationId, + (securityParameters.handlerArgs ++ methodParams.map(param => new NameExpr(param.getName.asString))).toNodeList + ) - method.setBody( - new BlockStmt( - new NodeList( - new ExpressionStmt(new MethodCallExpr(handlerCall, "whenComplete", new NodeList[Expression](whenCompleteLambda))) + method.setBody( + new BlockStmt( + ( + securityParameters.routeStatements :+ + new ExpressionStmt(new MethodCallExpr(handlerCall, "whenComplete", new NodeList[Expression](whenCompleteLambda))) + ).toNodeList ) ) - ) - val futureResponseType = completionStageType(responseType.clone()) - val handlerMethodSig = new MethodDeclaration(util.EnumSet.noneOf(classOf[Modifier]), futureResponseType, operationId) - (parameters.pathParams ++ parameters.headerParams ++ parameters.queryStringParams ++ parameters.formParams ++ parameters.bodyParams).foreach({ parameter => - handlerMethodSig.addParameter(parameter.param.clone()) - }) - handlerMethodSig.setBody(null) + val futureResponseType = completionStageType(responseType.clone()) + val handlerMethodSig = new MethodDeclaration(util.EnumSet.noneOf(classOf[Modifier]), futureResponseType, operationId) + (securityParameters.handlerParameters ++ + ( + parameters.pathParams ++ + parameters.headerParams ++ + parameters.queryStringParams ++ + parameters.formParams ++ + parameters.bodyParams + ).map(_.param)).foreach({ parameter => + handlerMethodSig.addParameter(parameter.clone()) + }) + handlerMethodSig.setBody(null) - Target.pure((method, handlerMethodSig)) + (method, handlerMethodSig) + } } object ServerTermInterp extends (ServerTerm[JavaLanguage, ?] ~> Target) { @@ -507,7 +632,7 @@ object DropwizardServerGenerator { processedRoutes <- routes .traverse({ case (_, _, RouteMeta(path, httpMethod, operation), parameters, responses) => - generateRoute(operation, path, commonPathPrefix, httpMethod, parameters, responses, protocolElems, handlerName) + generateRoute(operation, path, commonPathPrefix, httpMethod, parameters, responses, protocolElems, securitySchemes, handlerName) }) (routeMethods, handlerMethodSigs) = processedRoutes.unzip @@ -577,7 +702,8 @@ object DropwizardServerGenerator { shower <- SerializationHelpers.showerSupportDef - jersey <- SerializationHelpers.guardrailJerseySupportDef + jersey <- SerializationHelpers.guardrailJerseySupportDef + security <- DropwizardHelpers.httpSecurityUtilsSupportDef } yield { def httpMethodAnnotation(name: String): SupportDefinition[JavaLanguage] = { val annotationDecl = new AnnotationDeclaration(util.EnumSet.of(PUBLIC), name) @@ -593,6 +719,7 @@ object DropwizardServerGenerator { List( shower, jersey, + security, httpMethodAnnotation("PATCH"), httpMethodAnnotation("TRACE") ) diff --git a/modules/sample-dropwizard/src/test/java/core/Dropwizard/DropwizardSecurityTest.java b/modules/sample-dropwizard/src/test/java/core/Dropwizard/DropwizardSecurityTest.java new file mode 100644 index 0000000000..fd2622aa5a --- /dev/null +++ b/modules/sample-dropwizard/src/test/java/core/Dropwizard/DropwizardSecurityTest.java @@ -0,0 +1,160 @@ +package core.Dropwizard; + +import io.dropwizard.testing.junit.ResourceTestRule; +import org.glassfish.jersey.test.TestProperties; +import org.glassfish.jersey.test.grizzly.GrizzlyTestContainerFactory; +import org.junit.ClassRule; +import org.junit.Test; +import security.server.dropwizard.Handler; +import security.server.dropwizard.HttpSecurityUtils; +import security.server.dropwizard.Resource; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Optional; +import java.util.concurrent.CompletionStage; + +import static org.assertj.core.api.Assertions.assertThat; +import static java.util.concurrent.CompletableFuture.completedFuture; + +public class DropwizardSecurityTest { + static { + System.setProperty(TestProperties.CONTAINER_PORT, "0"); + } + + private static final String API_KEY = "sekrit-api-key"; + private static final String BASIC_USERNAME = "some-user"; + private static final String BASIC_PASSWORD = "my-sekrit-password"; + private static final String BEARER_TOKEN = "my-sekret-token"; + + private static final Handler handler = new Handler() { + @Override + public CompletionStage apiKeyHeader(Optional xApiKey, Optional foo) { + assertThat(xApiKey.get()).isEqualTo(API_KEY); + return completedFuture(ApiKeyHeaderResponse.Ok); + } + + @Override + public CompletionStage apiKeyQuery(Optional apiKey, Optional foo) { + assertThat(apiKey.get()).isEqualTo(API_KEY); + return completedFuture(ApiKeyQueryResponse.Ok); + } + + @Override + public CompletionStage httpBasicAuth(Optional httpBasicCredentials, Optional foo) { + assertThat(httpBasicCredentials.get().getUsername()).isEqualTo(BASIC_USERNAME); + assertThat(httpBasicCredentials.get().getPassword()).isEqualTo(BASIC_PASSWORD); + return completedFuture(HttpBasicAuthResponse.Ok); + } + + @Override + public CompletionStage httpBearerAuth(Optional httpBearerCredentials, Optional foo) { + assertThat(httpBearerCredentials.get().getToken()).isEqualTo(BEARER_TOKEN); + return completedFuture(HttpBearerAuthResponse.Ok); + } + + @Override + public CompletionStage multipleAuthAnd(Optional xApiKey, Optional httpBearerCredentials) { + assertThat(xApiKey.get()).isEqualTo(API_KEY); + assertThat(httpBearerCredentials.get().getToken()).isEqualTo(BEARER_TOKEN); + return completedFuture(MultipleAuthAndResponse.Ok); + } + + @Override + public CompletionStage multipleAuthOr(Optional apiKey, Optional httpBasicCredentials, Optional httpBearerCredentials) { + assertThat(apiKey.get()).isEqualTo(API_KEY); + assertThat(httpBasicCredentials.get().getUsername()).isEqualTo(BASIC_USERNAME); + assertThat(httpBasicCredentials.get().getPassword()).isEqualTo(BASIC_PASSWORD); + assertThat(httpBearerCredentials.isPresent()).isFalse(); + return completedFuture(MultipleAuthOrResponse.Ok); + } + }; + + @ClassRule + public static final ResourceTestRule resources = ResourceTestRule.builder() + .setTestContainerFactory(new GrizzlyTestContainerFactory()) + .addResource(new Resource(handler)) + .build(); + + @Test + public void testApiKeyQueryAuth() { + assertThat( + resources + .target("/api-key-query") + .queryParam("ApiKey", API_KEY) + .request() + .get() + .getStatus() + ).isEqualTo(200); + } + + @Test + public void testApiKeyHeaderAuth() { + assertThat( + resources + .target("/api-key-header") + .request() + .header("x-api-key", API_KEY) + .get() + .getStatus() + ).isEqualTo(200); + } + + @Test + public void testHttpBasicAuth() { + assertThat( + resources + .target("/http-basic") + .request() + .header("authorization", createHttpBasicHeader(BASIC_USERNAME, BASIC_PASSWORD)) + .get() + .getStatus() + ).isEqualTo(200); + } + + @Test + public void testHttpBearerAuth() { + assertThat( + resources + .target("/http-bearer") + .request() + .header("authorization", createHttpBearerHeader(BEARER_TOKEN)) + .get() + .getStatus() + ).isEqualTo(200); + } + + @Test + public void testMultipleOr() { + assertThat( + resources + .target("/multiple-or") + .queryParam("ApiKey", API_KEY) + .request() + .header("authorization", createHttpBasicHeader(BASIC_USERNAME, BASIC_PASSWORD)) + .get() + .getStatus() + ).isEqualTo(200); + } + + @Test + public void testMultipleAnd() { + assertThat( + resources + .target("/multiple-and") + .request() + .header("x-api-key", API_KEY) + .header("authorization", createHttpBearerHeader(BEARER_TOKEN)) + .get() + .getStatus() + ).isEqualTo(200); + } + + private static String createHttpBasicHeader(final String username, final String password) { + return "Basic " + Base64.getEncoder().encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8)); + } + + private static String createHttpBearerHeader(final String token) { + return "Bearer " + token; + } +} diff --git a/modules/sample/src/main/resources/petstore.json b/modules/sample/src/main/resources/petstore.json index 0aa45b7e6f..29a04aead0 100644 --- a/modules/sample/src/main/resources/petstore.json +++ b/modules/sample/src/main/resources/petstore.json @@ -963,13 +963,7 @@ }, "securityDefinitions": { "petstore_auth": { - "type": "oauth2", - "authorizationUrl": "http://petstore.swagger.io/oauth/dialog", - "flow": "implicit", - "scopes": { - "write:pets": "modify pets in your account", - "read:pets": "read your pets" - } + "type": "basic" }, "api_key": { "type": "apiKey", diff --git a/modules/sample/src/main/resources/security.yaml b/modules/sample/src/main/resources/security.yaml new file mode 100644 index 0000000000..04223e7636 --- /dev/null +++ b/modules/sample/src/main/resources/security.yaml @@ -0,0 +1,90 @@ +openapi: 3.0.1 +info: + name: Whatever + version: 1.0.0 +paths: + /api-key-header: + get: + security: + - ApiKeyHeader: [] + operationId: apiKeyHeader + parameters: + - name: foo + in: query + schema: + type: integer + format: int32 + responses: + 200: {} + /api-key-query: + get: + security: + - ApiKeyQuery: [] + operationId: apiKeyQuery + parameters: + - name: foo + in: query + schema: + type: integer + format: int32 + responses: + 200: {} + /http-basic: + get: + security: + - HttpBasicAuth: [] + operationId: httpBasicAuth + parameters: + - name: foo + in: query + schema: + type: integer + format: int32 + responses: + 200: {} + /http-bearer: + get: + security: + - HttpBearerAuth: [] + operationId: httpBearerAuth + parameters: + - name: foo + in: query + schema: + type: integer + format: int32 + responses: + 200: {} + /multiple-or: + get: + security: + - ApiKeyQuery: [] + - HttpBasicAuth: [] + - HttpBearerAuth: [] + operationId: multipleAuthOr + responses: + 200: {} + /multiple-and: + get: + security: + - ApiKeyHeader: [] + HttpBearerAuth: [] + operationId: multipleAuthAnd + responses: + 200: {} +components: + securitySchemes: + ApiKeyHeader: + type: apiKey + in: header + name: x-api-key + ApiKeyQuery: + type: apiKey + in: query + name: ApiKey + HttpBasicAuth: + type: http + scheme: basic + HttpBearerAuth: + type: http + scheme: bearer