diff --git a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnit.kt b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnit.kt index 5771771e40b..1218bb3e7b5 100644 --- a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnit.kt +++ b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnit.kt @@ -9,17 +9,21 @@ import io.gitlab.arturbosch.detekt.api.Rule import io.gitlab.arturbosch.detekt.api.Severity import io.gitlab.arturbosch.detekt.rules.isOverride import org.jetbrains.kotlin.cfg.WhenChecker +import org.jetbrains.kotlin.js.translate.callTranslator.getReturnType import org.jetbrains.kotlin.psi.KtBlockExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtIfExpression import org.jetbrains.kotlin.psi.KtNameReferenceExpression import org.jetbrains.kotlin.psi.KtNamedFunction +import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.psi.KtWhenExpression import org.jetbrains.kotlin.psi.psiUtil.siblings import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.bindingContextUtil.isUsedAsExpression +import org.jetbrains.kotlin.resolve.calls.callUtil.getResolvedCall import org.jetbrains.kotlin.resolve.calls.callUtil.getType import org.jetbrains.kotlin.types.typeUtil.isNothing +import org.jetbrains.kotlin.types.typeUtil.isTypeParameter import org.jetbrains.kotlin.types.typeUtil.isUnit import org.jetbrains.kotlin.utils.addToStdlib.firstIsInstanceOrNull @@ -56,8 +60,9 @@ class OptionalUnit(config: Config = Config.empty) : Rule(config) { ) override fun visitNamedFunction(function: KtNamedFunction) { - if (function.hasDeclaredReturnType()) { - checkFunctionWithExplicitReturnType(function) + val typeReference = function.typeReference + if (typeReference != null) { + checkFunctionWithExplicitReturnType(function, typeReference) } else if (!function.isOverride()) { checkFunctionWithInferredReturnType(function) } @@ -105,11 +110,11 @@ class OptionalUnit(config: Config = Config.empty) : Rule(config) { } } - private fun checkFunctionWithExplicitReturnType(function: KtNamedFunction) { - val typeReference = function.typeReference - val typeElementText = typeReference?.typeElement?.text + private fun checkFunctionWithExplicitReturnType(function: KtNamedFunction, typeReference: KtTypeReference) { + val typeElementText = typeReference.typeElement?.text if (typeElementText == UNIT) { - if (function.initializer.isNothingType()) return + val initializer = function.initializer + if (initializer?.isGenericOrNothingType() == true) return report(CodeSmell(issue, Entity.from(typeReference), createMessage(function))) } } @@ -124,8 +129,14 @@ class OptionalUnit(config: Config = Config.empty) : Rule(config) { private fun createMessage(function: KtNamedFunction) = "The function ${function.name} " + "defines a return type of Unit. This is unnecessary and can safely be removed." - private fun KtExpression?.isNothingType() = - bindingContext != BindingContext.EMPTY && this?.getType(bindingContext)?.isNothing() == true + private fun KtExpression.isGenericOrNothingType(): Boolean { + if (bindingContext == BindingContext.EMPTY) return false + val isGenericType = getResolvedCall(bindingContext)?.getReturnType()?.isTypeParameter() == true + val isNothingType = getType(bindingContext)?.isNothing() == true + // Either the function initializer returns Nothing or it is a generic function + // into which Unit is passed, but not both. + return (isGenericType && !isNothingType) || (isNothingType && !isGenericType) + } companion object { private const val UNIT = "Unit" diff --git a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnitSpec.kt b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnitSpec.kt index 3fe49e0ab9a..edaa59ecfee 100644 --- a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnitSpec.kt +++ b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/optional/OptionalUnitSpec.kt @@ -18,6 +18,21 @@ class OptionalUnitSpec : Spek({ val env: KotlinCoreEnvironment by memoized() describe("OptionalUnit rule") { + it("should report when a function has an explicit Unit return type with context") { + val code = """ + fun foo(): Unit { } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + } + + it("should not report when a function has a non-unit body expression") { + val code = """ + fun foo() = String + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } context("several functions which return Unit") { @@ -296,14 +311,72 @@ class OptionalUnitSpec : Spek({ val findings = subject.compileAndLintWithContext(env, code) assertThat(findings).hasSize(1) } + + it("another object is used as the last expression") { + val code = """ + fun foo() { + String + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } } - it("should not report when function initializer is Nothing") { - val code = """ + context("function initializers") { + it("should not report when function initializer is Nothing") { + val code = """ fun test(): Unit = throw UnsupportedOperationException() """ - val findings = subject.compileAndLintWithContext(env, code) - assertThat(findings).isEmpty() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + it("should not report when the function initializer requires a type") { + val code = """ + fun foo(block: (List) -> Unit): T { + val list = listOf() + block(list) + return list.first() + } + + fun doFoo(): Unit = foo {} + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + it("should report on function initializers when there is no context") { + val code = """ + fun test(): Unit = throw UnsupportedOperationException() + """ + val findings = subject.compileAndLint(code) + assertThat(findings).hasSize(1) + } + + it("should report when the function initializer takes in the type Nothing") { + val code = """ + fun foo(block: (List) -> Unit): T { + val list = listOf() + block(list) + return list.first() + } + + fun doFoo(): Unit = foo {} + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + } + + it("should report when the function initializer does not provide a different type") { + val code = """ + fun foo() {} + + fun doFoo(): Unit = foo() + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + } } } })