diff --git a/dsl/common/src/main/kotlin/io/mockk/Answers.kt b/dsl/common/src/main/kotlin/io/mockk/Answers.kt index bce54c344..296ea21b0 100644 --- a/dsl/common/src/main/kotlin/io/mockk/Answers.kt +++ b/dsl/common/src/main/kotlin/io/mockk/Answers.kt @@ -27,12 +27,15 @@ data class FunctionAnswer(val answerFunc: (Call) -> T) : Answer { data class CoFunctionAnswer(val answerFunc: suspend (Call) -> T) : Answer { override fun answer(call: Call): T { - val continuation = call.invocation.args.lastOrNull() as? Continuation<*> - ?: throw MockKException("last parameter is not Continuation<*> for suspend call") - - return InternalPlatformDsl.coroutineCall { - answerFunc(call) - }.callWithContinuation(continuation) + val lastParam = call.invocation.args.lastOrNull() + return if (lastParam is Continuation<*>) + InternalPlatformDsl.coroutineCall { + answerFunc(call) + }.callWithContinuation(lastParam) + else + InternalPlatformDsl.runCoroutine { + answerFunc(call) + } } override suspend fun coAnswer(call: Call) = answerFunc(call) diff --git a/dsl/common/src/main/kotlin/io/mockk/InternalPlatformDsl.kt b/dsl/common/src/main/kotlin/io/mockk/InternalPlatformDsl.kt index fbf72a0ee..08f114650 100644 --- a/dsl/common/src/main/kotlin/io/mockk/InternalPlatformDsl.kt +++ b/dsl/common/src/main/kotlin/io/mockk/InternalPlatformDsl.kt @@ -1,7 +1,6 @@ package io.mockk import kotlin.coroutines.experimental.Continuation -import kotlin.reflect.KCallable expect object InternalPlatformDsl { fun identityHashCode(obj: Any): Int diff --git a/mockk/common/src/main/kotlin/io/mockk/impl/recording/ChainedCallDetector.kt b/mockk/common/src/main/kotlin/io/mockk/impl/recording/ChainedCallDetector.kt index 537c1e408..31028a056 100644 --- a/mockk/common/src/main/kotlin/io/mockk/impl/recording/ChainedCallDetector.kt +++ b/mockk/common/src/main/kotlin/io/mockk/impl/recording/ChainedCallDetector.kt @@ -6,7 +6,6 @@ import io.mockk.InternalPlatformDsl.toStr import io.mockk.impl.InternalPlatform import io.mockk.impl.log.Logger import io.mockk.impl.log.SafeToString -import kotlin.reflect.KClass class ChainedCallDetector(safeToString: SafeToString) { val log = safeToString(Logger()) diff --git a/mockk/common/src/test/kotlin/io/mockk/impl/recording/ChainedCallDetectorTest.kt b/mockk/common/src/test/kotlin/io/mockk/impl/recording/ChainedCallDetectorTest.kt index b13e0b0cf..b1b8c9c61 100644 --- a/mockk/common/src/test/kotlin/io/mockk/impl/recording/ChainedCallDetectorTest.kt +++ b/mockk/common/src/test/kotlin/io/mockk/impl/recording/ChainedCallDetectorTest.kt @@ -30,6 +30,9 @@ class ChainedCallDetectorTest { every { call1.method.varArgsArg } returns -1 every { call2.method.varArgsArg } returns -1 + every { call1.method.isSuspend } returns { false } + every { call2.method.isSuspend } returns { false } + detector.detect(listOf(callRound1, callRound2), 0, hashMapOf()) assertEquals("abc", detector.call.matcher.method.name) @@ -56,6 +59,9 @@ class ChainedCallDetectorTest { every { call1.method.varArgsArg } returns -1 every { call2.method.varArgsArg } returns -1 + every { call1.method.isSuspend } returns { false } + every { call2.method.isSuspend } returns { false } + detector.detect(listOf(callRound1, callRound2), 0, matcherMap) assertEquals("abc", detector.call.matcher.method.name) diff --git a/mockk/jvm/src/main/kotlin/io/mockk/impl/instantiation/JvmMockFactoryHelper.kt b/mockk/jvm/src/main/kotlin/io/mockk/impl/instantiation/JvmMockFactoryHelper.kt index d2cd16b00..a22c97747 100644 --- a/mockk/jvm/src/main/kotlin/io/mockk/impl/instantiation/JvmMockFactoryHelper.kt +++ b/mockk/jvm/src/main/kotlin/io/mockk/impl/instantiation/JvmMockFactoryHelper.kt @@ -87,17 +87,17 @@ object JvmMockFactoryHelper { private fun Method.toDescription(): MethodDescription { - val isSuspend: () -> Boolean = if ( - parameterTypes.lastOrNull()?.let { Continuation::class.java.isAssignableFrom(it) } == true - ) { - { - kotlinFunction?.isSuspend ?: false - } + val lastParam = parameterTypes.lastOrNull() + val isLastParamContinuation = lastParam?.let { + Continuation::class.java.isAssignableFrom(it) + } ?: false + + val isSuspend: () -> Boolean = if (isLastParamContinuation) { + { kotlinFunction?.isSuspend ?: false } } else { { false } } - return MethodDescription( name, returnType.kotlin,