diff --git a/modules/mockk-agent/api/mockk-agent.api b/modules/mockk-agent/api/mockk-agent.api index a29884c89..c9a8d6007 100644 --- a/modules/mockk-agent/api/mockk-agent.api +++ b/modules/mockk-agent/api/mockk-agent.api @@ -137,3 +137,10 @@ public class io/mockk/proxy/jvm/dispatcher/JvmMockKWeakMap : java/util/Map { public fun values ()Ljava/util/Collection; } +public final class io/mockk/proxy/jvm/util/DefaultInterfaceMethodResolver { +public static final field Companion Lio/mockk/proxy/jvm/util/DefaultInterfaceMethodResolver$Companion; +public fun ()V +} + +public final class io/mockk/proxy/jvm/util/DefaultInterfaceMethodResolver$Companion {} + diff --git a/modules/mockk-agent/src/jvmMain/java/io/mockk/proxy/jvm/advice/jvm/JvmMockKProxyInterceptor.java b/modules/mockk-agent/src/jvmMain/java/io/mockk/proxy/jvm/advice/jvm/JvmMockKProxyInterceptor.java index 1f937c30a..d68a7247f 100644 --- a/modules/mockk-agent/src/jvmMain/java/io/mockk/proxy/jvm/advice/jvm/JvmMockKProxyInterceptor.java +++ b/modules/mockk-agent/src/jvmMain/java/io/mockk/proxy/jvm/advice/jvm/JvmMockKProxyInterceptor.java @@ -1,13 +1,12 @@ package io.mockk.proxy.jvm.advice.jvm; -import io.mockk.proxy.MockKInvocationHandler; import io.mockk.proxy.jvm.advice.BaseAdvice; import io.mockk.proxy.jvm.advice.ProxyAdviceId; import io.mockk.proxy.jvm.dispatcher.JvmMockKDispatcher; +import io.mockk.proxy.jvm.util.DefaultInterfaceMethodResolver; import net.bytebuddy.implementation.bind.annotation.*; import java.lang.reflect.Method; -import java.util.Map; import java.util.concurrent.Callable; public class JvmMockKProxyInterceptor extends BaseAdvice { @@ -42,7 +41,8 @@ public static Object interceptNoSuper(@ProxyAdviceId long id, return null; } - return dispatcher.handle(self, method, args, null); + return dispatcher.handle(self, method, args, DefaultInterfaceMethodResolver.Companion.getDefaultImplementationOrNull$mockk_agent(self, method, args)); + } } \ No newline at end of file diff --git a/modules/mockk-agent/src/jvmMain/kotlin/io/mockk/proxy/jvm/util/DefaultInterfaceMethodResolver.kt b/modules/mockk-agent/src/jvmMain/kotlin/io/mockk/proxy/jvm/util/DefaultInterfaceMethodResolver.kt new file mode 100644 index 000000000..e95a038bf --- /dev/null +++ b/modules/mockk-agent/src/jvmMain/kotlin/io/mockk/proxy/jvm/util/DefaultInterfaceMethodResolver.kt @@ -0,0 +1,39 @@ +package io.mockk.proxy.jvm.util + +import io.mockk.proxy.jvm.advice.MethodCall +import java.lang.reflect.Method +import java.lang.reflect.Modifier + +class DefaultInterfaceMethodResolver { + + companion object { + + internal fun getDefaultImplementationOrNull(mock: Any, method: Method, arguments: Array): MethodCall? = + findDefaultImplMethod(method) + ?.let { + val defaultImplMethodArguments = arrayOf(mock, *arguments) + MethodCall(mock, it, defaultImplMethodArguments) + } + + private fun findDefaultImplMethod(method: Method): Method? = + method.takeIf { Modifier.isAbstract(it.modifiers) } + ?.declaringClass + ?.let { declaringClass -> + findDefaultImplsClass(declaringClass) + ?.runCatching { + getMethod(method.name, declaringClass, *method.parameterTypes.requireNoNulls()) + } + ?.getOrNull() + ?.takeIf { Modifier.isStatic(it.modifiers) } + } + + private fun findDefaultImplsClass(clazz: Class<*>): Class<*>? = + clazz.takeIf { it.isInterface && isKotlinClass(it) } + ?.classes?.firstOrNull { it.simpleName == "DefaultImpls" && Modifier.isStatic(it.modifiers) } + + private fun isKotlinClass(clazz: Class<*>): Boolean { + return clazz.isAnnotationPresent(Metadata::class.java) + } + } + +} diff --git a/modules/mockk-agent/src/jvmTest/java/io/mockk/proxy/util/DefaultInterfaceMethodResolverTest.kt b/modules/mockk-agent/src/jvmTest/java/io/mockk/proxy/util/DefaultInterfaceMethodResolverTest.kt new file mode 100644 index 000000000..ca5681ec6 --- /dev/null +++ b/modules/mockk-agent/src/jvmTest/java/io/mockk/proxy/util/DefaultInterfaceMethodResolverTest.kt @@ -0,0 +1,69 @@ +package io.mockk.proxy.util + +import io.mockk.proxy.jvm.util.DefaultInterfaceMethodResolver +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertNull +import kotlin.test.Test + + +class DefaultInterfaceMethodResolverTest { + + interface A { + fun method() + fun defaultMethod(arg: String): String { + return "Arg: $arg" + } + } + + class B : A { + fun subclassMethod(arg: String) {} + override fun method() { + } + } + + @Test + fun `should return MethodCall when default implementation exists`() { + val subclass = B() + val method = A::class.java.getMethod("defaultMethod", String::class.java) + val arguments = arrayOfNulls(1).also { it[0] = "arg" } + + val result = DefaultInterfaceMethodResolver.getDefaultImplementationOrNull(subclass, method, arguments) + + assertNotNull(result) + + } + + @Test + fun `should return null when is concrete class method`() { + val subclass = B() + val method = B::class.java.getMethod("subclassMethod", String::class.java) + val arguments = arrayOfNulls(1).also { it[0] = "arg" } + + val result = DefaultInterfaceMethodResolver.getDefaultImplementationOrNull(subclass, method, arguments) + + assertNull(result) + } + + @Test + fun `should return null when method is overwritten`() { + val subclass = B() + val method = A::class.java.getDeclaredMethod("method") + val arguments = arrayOfNulls(0) + + val result = DefaultInterfaceMethodResolver.getDefaultImplementationOrNull(subclass, method, arguments) + + assertNull(result) + } + + @Test + fun `should return null when method is not a Kotlin class`() { + val subclass = ArrayList() + val method = ArrayList::class.java.getDeclaredMethod("add", Any::class.java) + val arguments = arrayOfNulls(1).also { it[0] = "element" } + + val result = DefaultInterfaceMethodResolver.getDefaultImplementationOrNull(subclass, method, arguments) + + assertNull(result) + } + +} \ No newline at end of file diff --git a/modules/mockk/src/jvmTest/kotlin/io/mockk/it/CallOriginalOnDefaultInterfaceMethodTest.kt b/modules/mockk/src/jvmTest/kotlin/io/mockk/it/CallOriginalOnDefaultInterfaceMethodTest.kt new file mode 100644 index 000000000..59acbf2d4 --- /dev/null +++ b/modules/mockk/src/jvmTest/kotlin/io/mockk/it/CallOriginalOnDefaultInterfaceMethodTest.kt @@ -0,0 +1,42 @@ +package io.mockk.it + +import io.mockk.* +import kotlin.test.Test + +class CallOriginalOnDefaultInterfaceMethodTest { + + interface A { + fun method1(items: List) + fun method2(items: List) + fun defaultMethod(callMethod2: Boolean) { + method1(listOf(1, 2, 3)) + if (callMethod2) + method2(listOf(4, 5, 6)) + } + } + + @Test + fun `should call the original default method when spy the class`() { + val spy = spyk() + every { spy.defaultMethod(any()) } answers { callOriginal() } + + spy.defaultMethod(callMethod2 = true) + + verify { spy.method1(listOf(1, 2, 3)) } + verify { spy.method2(listOf(4, 5, 6)) } + } + + @Test + fun `should call the original default method when mock the class`() { + val mock = mockk() + every { mock.defaultMethod(any()) } answers { callOriginal() } + every { mock.method1(any()) } just runs + every { mock.method2(any()) } just runs + + mock.defaultMethod(callMethod2 = true) + + verify { mock.method1(listOf(1, 2, 3)) } + verify { mock.method2(listOf(4, 5, 6)) } + } + +}