diff --git a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/AsyncHttpClientClientGenerator.scala b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/AsyncHttpClientClientGenerator.scala index 7592999baf..1085d7eec0 100644 --- a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/AsyncHttpClientClientGenerator.scala +++ b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/AsyncHttpClientClientGenerator.scala @@ -12,7 +12,7 @@ import com.github.javaparser.ast.body.{ClassOrInterfaceDeclaration, MethodDeclar import com.github.javaparser.ast.expr.{MethodCallExpr, NameExpr, _} import com.github.javaparser.ast.stmt._ import com.twilio.guardrail.SwaggerUtil.jpaths -import com.twilio.guardrail.generators.Response +import com.twilio.guardrail.generators.{Response, ScalaParameter} import com.twilio.guardrail.generators.syntax.Java._ import com.twilio.guardrail.languages.JavaLanguage import com.twilio.guardrail.protocol.terms.client._ @@ -20,7 +20,7 @@ import com.twilio.guardrail.terms.RouteMeta import com.twilio.guardrail.{RenderedClientOperation, StaticDefns, Target} import java.net.URI import java.util -import java.util.Locale +import javax.lang.model.`type`.PrimitiveType object AsyncHttpClientClientGenerator { private val URI_TYPE = JavaParser.parseClassOrInterfaceType("URI") @@ -31,6 +31,7 @@ object AsyncHttpClientClientGenerator { private val REQUEST_BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("RequestBuilder") private val REQUEST_TYPE = JavaParser.parseClassOrInterfaceType("Request") private val RESPONSE_TYPE = JavaParser.parseClassOrInterfaceType("Response") + private val STRING_PART_TYPE = JavaParser.parseClassOrInterfaceType("StringPart") private val OBJECT_MAPPER_TYPE = JavaParser.parseClassOrInterfaceType("ObjectMapper") private val BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("Builder") private val MARSHALLING_EXCEPTION_TYPE = JavaParser.parseClassOrInterfaceType("MarshallingException") @@ -42,6 +43,82 @@ object AsyncHttpClientClientGenerator { completionStageType.setTypeArguments(RESPONSE_TYPE) )) + private def showParam(param: ScalaParameter[JavaLanguage], overrideParamName: Option[String] = None): Expression = { + val paramName = overrideParamName.getOrElse(param.paramName.asString) + + def doShow(tpe: Type): Expression = tpe match { + case _: PrimitiveType => + new MethodCallExpr(new NameExpr("String"), "valueOf", new NodeList[Expression](new NameExpr(paramName))) + case cls: ClassOrInterfaceType if cls.isOptional => + doShow(cls.containedType) + case cls: ClassOrInterfaceType if cls.isBoxedType => + new MethodCallExpr(new NameExpr(paramName), "toString") + case cls: ClassOrInterfaceType if cls.isNamed("List") => + doShow(cls.containedType) + case cls: ClassOrInterfaceType if cls.getName.asString() == "String" => + new NameExpr(paramName) + case _: ClassOrInterfaceType => + // FIXME: this will cover our autogenerated enum types, but it would be really nice if we could at identify + // our enum times via a list of some sort. this will fail to compile in most non-enum cases, but there's + // the possibility that this could generate incorrect but compilable code. + new MethodCallExpr(new NameExpr(paramName), "getValue") + case _: VoidType => + new NullLiteralExpr + case other => + println(s"WARN: Unhandled arg type ${other.getClass.getName} for arg typed ${other.name} ${param.paramName}") + new NameExpr("UNSUPPORTED_PARAMETER_TYPE_PLEASE_FILE_AN_ISSUE") + } + + doShow(param.argType) + } + + private def optionIfPresent(optionVarType: Type, optionVarName: String, innerStatement: Statement): Statement = { + new ExpressionStmt(new MethodCallExpr(new NameExpr(optionVarName), "ifPresent", new NodeList[Expression]( + new LambdaExpr(new NodeList(new Parameter(util.EnumSet.of(FINAL), optionVarType, new SimpleName("arg"))), + innerStatement, + true + )) + )) + } + + private def generateBuilderMethodCalls(params: List[ScalaParameter[JavaLanguage]], builderMethodName: String): List[Statement] = { + val needsMultipart = params.exists(_.isFile) + params.map({ param => + val finalMethodName = if (needsMultipart) "addBodyPart" else builderMethodName + val argName = if (param.required) param.paramName.asString else "arg" + val containedType = param.argType.containedType + val isList = if (param.required) param.argType.isNamed("List") else containedType.isNamed("List") + val listType = if (param.required) containedType else containedType.containedType + + val makeArgList: String => NodeList[Expression] = name => + if (containedType.isNamed("FilePart") || listType.isNamed("FilePart")) { + new NodeList[Expression](new NameExpr(name)) + } else if (needsMultipart) { + new NodeList[Expression](new ObjectCreationExpr(null, STRING_PART_TYPE, new NodeList(showParam(param, Some(name))))) + } else { + new NodeList[Expression](new StringLiteralExpr(param.argName.value), showParam(param, Some(name))) + } + + val builderStatement: Statement = if (isList) { + new ForEachStmt( + new VariableDeclarationExpr(listType, "member", FINAL), + new NameExpr(argName), + new BlockStmt(new NodeList( + new ExpressionStmt(new MethodCallExpr(new NameExpr("builder"), finalMethodName, makeArgList("member"))) + )) + ) + } else { + new ExpressionStmt(new MethodCallExpr(new NameExpr("builder"), finalMethodName, makeArgList(argName))) + } + + if (param.required) { + builderStatement + } else { + optionIfPresent(containedType, param.paramName.asString, builderStatement) + } + }) + } + object ClientTermInterp extends (ClientTerm[JavaLanguage, ?] ~> Target) { def apply[T](term: ClientTerm[JavaLanguage, T]): Target[T] = term match { case GenerateClientOperation(_, RouteMeta(pathStr, httpMethod, operation), methodName, tracing, parameters, responses) => @@ -71,13 +148,19 @@ object AsyncHttpClientClientGenerator { "setUrl", new NodeList[Expression](pathExpr) ) + val builderMethodCalls: List[Statement] = List( + generateBuilderMethodCalls(parameters.queryStringParams, "addQueryParam"), + generateBuilderMethodCalls(parameters.formParams, "addFormParam"), + generateBuilderMethodCalls(parameters.headerParams, "addHeader") + ).flatten + val httpMethodCallExpr = new MethodCallExpr( new FieldAccessExpr(new ThisExpr, "httpClient"), "apply", new NodeList[Expression](new MethodCallExpr(new NameExpr("builder"), "build")) ) val requestCall = new MethodCallExpr(httpMethodCallExpr, "thenApply", new NodeList[Expression]( - new LambdaExpr(new NodeList(new Parameter(RESPONSE_TYPE, "response")), new BlockStmt(new NodeList( + new LambdaExpr(new NodeList(new Parameter(util.EnumSet.of(FINAL), RESPONSE_TYPE, new SimpleName("response"))), new BlockStmt(new NodeList( new SwitchStmt(new MethodCallExpr(new NameExpr("response"), "getStatusCode"), new NodeList( responses.value.map(response => new SwitchEntryStmt(new IntegerLiteralExpr(response.statusCode), new NodeList(response.value match { case None => new ReturnStmt(new ObjectCreationExpr(null, JavaParser.parseClassOrInterfaceType(s"${responseParentName}.${response.statusCodeName.asString}"), new NodeList())) @@ -120,8 +203,9 @@ object AsyncHttpClientClientGenerator { )) method.setBody(new BlockStmt(new NodeList( - new ExpressionStmt(requestBuilder), - new ReturnStmt(requestCall) + new ExpressionStmt(requestBuilder) +: + builderMethodCalls :+ + new ReturnStmt(requestCall): _* ))) RenderedClientOperation[JavaLanguage](method, List.empty) diff --git a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/JacksonGenerator.scala b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/JacksonGenerator.scala index dc2e406853..8a671112c6 100644 --- a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/JacksonGenerator.scala +++ b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/Java/JacksonGenerator.scala @@ -34,7 +34,7 @@ object JacksonGenerator { private def sortParams(params: List[ProtocolParameter[JavaLanguage]]): (List[ParameterTerm], List[ParameterTerm]) = { // TODO: if a required field has a default specified, include it in optionalTerms instead val (req, opt) = params.partition(_.term.getType match { - case cls: ClassOrInterfaceType => !isOptionalType(cls) + case cls: ClassOrInterfaceType => !cls.isOptional case _ => true }) @@ -70,9 +70,6 @@ object JacksonGenerator { }) } - private def isOptionalType(cls: ClassOrInterfaceType): Boolean = - (cls.getScope.asScala.fold("")(_.asString + ".") + cls.getName.asString) == "java.util.Optional" - private def lookupTypeName(tpeName: String, concreteTypes: List[PropMeta[JavaLanguage]])(f: Type => Target[Type]): Option[Target[Type]] = concreteTypes .find(_.clsName == tpeName) diff --git a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/syntax/Java.scala b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/syntax/Java.scala index d8556260ef..154a15d6cb 100644 --- a/modules/codegen/src/main/scala/com/twilio/guardrail/generators/syntax/Java.scala +++ b/modules/codegen/src/main/scala/com/twilio/guardrail/generators/syntax/Java.scala @@ -18,6 +18,35 @@ object Java { def asScala: Option[T] = if (o.isPresent) Option(o.get) else None } + implicit class RichType(val tpe: Type) extends AnyVal { + def isOptional: Boolean = + tpe match { + case cls: ClassOrInterfaceType => + val scope = cls.getScope.asScala + cls.getNameAsString == "Optional" && (scope.isEmpty || scope.map(_.asString).contains("java.util")) + case _ => false + } + + def containedType: Type = + tpe match { + case cls: ClassOrInterfaceType => cls.getTypeArguments.asScala.filter(_.size == 1).fold(tpe)(_.get(0)) + case _ => tpe + } + + def isNamed(name: String): Boolean = + tpe match { + case cls: ClassOrInterfaceType if name.contains(".") => (cls.getScope.asScala.fold("")(_ + ".") + cls.getNameAsString) == name + case cls: ClassOrInterfaceType => cls.getNameAsString == name + case _ => false + } + + def name: Option[String] = + tpe match { + case cls: ClassOrInterfaceType => Some(cls.getScope.asScala.fold("")(_ + ".") + cls.getNameAsString) + case _ => None + } + } + private[this] def safeParse[T](log: String)(parser: String => T, s: String)(implicit cls: ClassTag[T]): Target[T] = { Target.log.debug(log)(s) >> ( Try(parser(s)).toEither.fold(t => Target.raiseError(s"Unable to parse '${s}' to a ${cls.runtimeClass.getName}: ${t.getMessage}"), Target.pure)