Skip to content

Commit

Permalink
Use CodeBlock.Builder to construct routing code
Browse files Browse the repository at this point in the history
Removes the need for "global imports" and improves the generated code by avoiding unused imports
  • Loading branch information
ulrikandersen committed Apr 22, 2024
1 parent c8e1360 commit 916f6b8
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 84 deletions.
8 changes: 2 additions & 6 deletions src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenerator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,9 @@ class CodeGenerator(

val controllerFiles: Collection<FileSpec> = generator.generate().files
val libFiles: Collection<FileSpec> = generator.generateLibrary().map {
val builder = FileSpec.builder(it.className)
FileSpec.builder(it.className)
.addType(it.spec)

// add imports for the library files
it.imports.forEach { import -> builder.addImport(import.packageName, import.name) }

builder.build()
.build()
}

return controllerFiles.plus(libFiles)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import com.cjbooms.fabrikt.model.BodyParameter
import com.cjbooms.fabrikt.model.ControllerLibraryType
import com.cjbooms.fabrikt.model.ControllerType
import com.cjbooms.fabrikt.model.HeaderParam
import com.cjbooms.fabrikt.model.Import
import com.cjbooms.fabrikt.model.IncomingParameter
import com.cjbooms.fabrikt.model.KotlinTypes
import com.cjbooms.fabrikt.model.PathParam
Expand All @@ -25,9 +24,11 @@ import com.cjbooms.fabrikt.util.toUpperCase
import com.reprezen.kaizen.oasparser.model3.Operation
import com.reprezen.kaizen.oasparser.model3.Path
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.PropertySpec
Expand Down Expand Up @@ -129,30 +130,13 @@ class KtorControllerInterfaceGenerator(
controllerBuilder
}

// imports for the generated code
val globalImports = listOf( // imports necessary for ktor
"io.ktor.server.application" to "call",
"io.ktor.server.application" to "ApplicationCall",
"io.ktor.server.auth" to "authenticate",
"io.ktor.server.plugins" to "BadRequestException",
"io.ktor.server.request" to "receive",
"io.ktor.server.response" to "respond",
"io.ktor.server.util" to "getOrFail",
"io.ktor.util.reflect" to "typeInfo",
"io.ktor.util.converters" to "DefaultConversionService",
"io.ktor.server.plugins" to "ParameterConversionException",
)
.plus(usedVerbs.map { "io.ktor.server.routing" to it })
.map { Import(it.first, it.second) }
.toSet()

val controllers = controllerInterfaces.map { ControllerType(it.build(), packages.base) }.toSet()

return KtorControllers(controllers, globalImports)
return KtorControllers(controllers)
}

private fun buildRouteCode(operation: Operation, verb: String, path: Map.Entry<String, Path>): String {
val codeBlock = StringBuilder()
private fun buildRouteCode(operation: Operation, verb: String, path: Map.Entry<String, Path>): CodeBlock {
val builder = CodeBlock.builder()

val globalSecurity = this.api.openApi3.securityRequirements.securitySupport()
val securityOption = operation.securitySupport(globalSecurity)
Expand All @@ -167,11 +151,13 @@ class KtorControllerInterfaceGenerator(
this.api.openApi3.securityRequirements.first().requirements.keys.first()
}

if (securityOption == SecuritySupport.AUTHENTICATION_OPTIONAL) {
codeBlock.appendLine("authenticate(\"$authName\", optional = true) {⇥")
} else {
codeBlock.appendLine("authenticate(\"$authName\") {⇥")
}
builder
.addStatement(
"%M(\"$authName\", optional = %L) {",
MemberName("io.ktor.server.auth", "authenticate"),
securityOption == SecuritySupport.AUTHENTICATION_OPTIONAL,
)
.indent()
}

val params = operation.toIncomingParameters(packages.base, path.value.parameters, emptyList())
Expand All @@ -181,31 +167,58 @@ class KtorControllerInterfaceGenerator(

val methodName = getMethodName(operation, verb, path)

codeBlock.appendLine("${verb}(\"${path.key}\") {⇥")
builder
.addStatement(
"%M(%S) {",
MemberName("io.ktor.server.routing", verb),
path.key,
)
.indent()

pathParams.forEach { param ->
codeBlock.appendLine("val ${param.name} = call.parameters.getOrFail<${param.type}>(\"${param.originalName}\")")
builder.addStatement(
"val ${param.name} = %M.parameters.%M<${param.type}>(\"${param.originalName}\")",
MemberName("io.ktor.server.application", "call"),
MemberName("io.ktor.server.util", "getOrFail", isExtension = true),
)
}

headerParams.forEach { param ->
if (param.isRequired) {
codeBlock.appendLine("val ${param.name} = call.request.headers.getOrFail(\"${param.originalName}\")")
builder.addStatement(
"val ${param.name} = %M.request.headers.getOrFail(\"${param.originalName}\")",
MemberName("io.ktor.server.application", "call"),
)
} else {
codeBlock.appendLine("val ${param.name} = call.request.headers[\"${param.originalName}\"]")
builder.addStatement(
"val ${param.name} = %M.request.headers[\"${param.originalName}\"]",
MemberName("io.ktor.server.application", "call"),
)
}
}

queryParams.forEach { param ->
val typeName = param.type.copy(nullable = false) // not nullable because we handle that in the queryParameters.get* below
if (param.isRequired) {
codeBlock.appendLine("val ${param.name} = call.request.queryParameters.getOrFail<$typeName>(\"${param.originalName}\")")
builder.addStatement(
"val ${param.name} = %M.request.queryParameters.%M<$typeName>(\"${param.originalName}\")",
MemberName("io.ktor.server.application", "call"),
MemberName("io.ktor.server.util", "getOrFail", isExtension = true),
)
} else {
codeBlock.appendLine("val ${param.name} = call.request.queryParameters.getTyped<$typeName>(\"${param.originalName}\")")
builder.addStatement(
"val ${param.name} = %M.request.queryParameters.getTyped<$typeName>(\"${param.originalName}\")",
MemberName("io.ktor.server.application", "call"),
)
}
}

bodyParams.forEach { param ->
codeBlock.appendLine("val ${param.name} = call.receive<${param.type.simpleName()}>()")
builder.addStatement(
"val ${param.name} = %M.%M<${param.type.simpleName()}>()",
MemberName("io.ktor.server.application", "call"),
MemberName("io.ktor.server.request", "receive"),
)
}

val methodParameters =
Expand All @@ -216,24 +229,36 @@ class KtorControllerInterfaceGenerator(
else "call"
}

codeBlock.appendLine("val result = controller.$methodName($methodParameters)")
builder.addStatement("val result = controller.$methodName($methodParameters)")

if (happyPathResponse.simpleName() == Unit::class.simpleName) {
// When return type is Unit we only respond with the status code.
// Note however that in some cases the controller may choose to respond directly on the call
// in which case that takes precedence in Ktor's routing. For example: call.respondRedirect().
codeBlock.appendLine("call.respond(result.status)")
builder.addStatement(
"%M.%M(result.status)",
MemberName("io.ktor.server.application", "call"),
MemberName("io.ktor.server.response", "respond"),
)
} else {
codeBlock.appendLine("call.respond(result.status, result.message)")
builder.addStatement(
"%M.%M(result.status, result.message)",
MemberName("io.ktor.server.application", "call"),
MemberName("io.ktor.server.response", "respond"),
)
}

codeBlock.appendLine("⇤}")
builder
.unindent()
.addStatement("}")

if (addAuth) {
codeBlock.appendLine("⇤}")
builder
.unindent()
.addStatement("}")
}

return codeBlock.toString()
return builder.build()
}

override fun generateLibrary(): Collection<ControllerLibraryType> = listOf(
Expand Down Expand Up @@ -287,14 +312,17 @@ class KtorControllerInterfaceGenerator(
.returns(returnType)
.addCode("""
val values = getAll(name) ?: return null
val typeInfo = typeInfo<R>()
val typeInfo = %M<R>()
return try {
@Suppress("UNCHECKED_CAST")
DefaultConversionService.fromValues(values, typeInfo) as R
%M.fromValues(values, typeInfo) as R
} catch (cause: Exception) {
throw ParameterConversionException(name, typeInfo.type.simpleName ?: typeInfo.type.toString(), cause)
throw %M(name, typeInfo.type.simpleName ?: typeInfo.type.toString(), cause)
}
""".trimIndent()
""".trimIndent(),
MemberName("io.ktor.util.reflect", "typeInfo"),
MemberName("io.ktor.util.converters", "DefaultConversionService",),
MemberName("io.ktor.server.plugins", "ParameterConversionException")
)
.addKdoc("""
Gets parameter value associated with this name or null if the name is not present.
Expand All @@ -316,8 +344,8 @@ class KtorControllerInterfaceGenerator(
.returns(String::class)
.addParameter("name", String::class)
.addCode("""
return this[name] ?: throw BadRequestException("Header " + name + " is required")
""".trimIndent())
return this[name] ?: throw %M("Header " + name + " is required")
""".trimIndent(), MemberName("io.ktor.server.plugins", "BadRequestException"))
.addKdoc("""
Gets first value from the list of values associated with a name.
Expand All @@ -339,13 +367,9 @@ class KtorControllerInterfaceGenerator(

data class KtorControllers(
val controllers: Set<ControllerType>,
val globalImports: Set<Import>,
) : KotlinTypes(controllers) {
override val files: Collection<FileSpec> = super.files.map { fileSpec ->
fileSpec.toBuilder().apply {
// Could be improved by only importing what is necessary in the specific file
globalImports.forEach { addImport(it.packageName, it.name) }
}.build()
fileSpec.toBuilder().build()
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions src/main/kotlin/com/cjbooms/fabrikt/model/KotlinGenModels.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,9 @@ class ControllerType(spec: TypeSpec, basePackage: String) : GeneratedType(spec,
}
}

class ControllerLibraryType(spec: TypeSpec, basePackage: String, val imports: Collection<Import> = emptySet()) :
class ControllerLibraryType(spec: TypeSpec, basePackage: String) :
GeneratedType(spec, controllersPackage(basePackage))

data class Import(val packageName: String, val name: String)

data class Models(val models: Collection<ModelType>) : KotlinTypes(models) {
override val files: Collection<FileSpec> = models.toFileSpec()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,12 @@ class KtorControllerInterfaceGeneratorTest {
val destPackage = if (controllers.isNotEmpty()) controllers.first().destinationPackage else ""
val singleFileBuilder = FileSpec.builder(destPackage, "dummyFilename")

globalImports.forEach {
singleFileBuilder.addImport(it.packageName, it.name)
}

controllers.forEach {
singleFileBuilder.addType(it.spec)
}

extraSpecs.forEach {
singleFileBuilder.addType(it.spec)
it.imports.forEach { import -> singleFileBuilder.addImport(import.packageName, import.name) }
}

val singleFileString = singleFileBuilder.build().toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ public interface RequiredController {

public companion object {
public fun Route.requiredRoutes(controller: RequiredController) {
authenticate("BasicAuth") {
get("/required") {
authenticate("BasicAuth", optional = false) {
`get`("/required") {
val testString = call.request.queryParameters.getOrFail<kotlin.String>("testString")
val result = controller.testPath(call, testString)
call.respond(result.status)
Expand Down Expand Up @@ -81,7 +81,7 @@ public interface ProhibitedController {

public companion object {
public fun Route.prohibitedRoutes(controller: ProhibitedController) {
get("/prohibited") {
`get`("/prohibited") {
val testString = call.request.queryParameters.getOrFail<kotlin.String>("testString")
val result = controller.testPath(call, testString)
call.respond(result.status)
Expand Down Expand Up @@ -133,7 +133,7 @@ public interface OptionalController {
public companion object {
public fun Route.optionalRoutes(controller: OptionalController) {
authenticate("BasicAuth", optional = true) {
get("/optional") {
`get`("/optional") {
val testString = call.request.queryParameters.getOrFail<kotlin.String>("testString")
val result = controller.testPath(call, testString)
call.respond(result.status)
Expand Down Expand Up @@ -185,7 +185,7 @@ public interface NoneController {

public companion object {
public fun Route.noneRoutes(controller: NoneController) {
get("/none") {
`get`("/none") {
val testString = call.request.queryParameters.getOrFail<kotlin.String>("testString")
val result = controller.testPath(call, testString)
call.respond(result.status)
Expand Down Expand Up @@ -236,8 +236,8 @@ public interface DefaultController {

public companion object {
public fun Route.defaultRoutes(controller: DefaultController) {
authenticate("basicAuth") {
get("/default") {
authenticate("basicAuth", optional = false) {
`get`("/default") {
val testString = call.request.queryParameters.getOrFail<kotlin.String>("testString")
val result = controller.testPath(call, testString)
call.respond(result.status)
Expand Down
Loading

0 comments on commit 916f6b8

Please sign in to comment.