Skip to content

Commit

Permalink
isMockKMock + return set of mocks #221
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksiyp committed Feb 9, 2019
1 parent 510ed81 commit 5efe6b6
Show file tree
Hide file tree
Showing 16 changed files with 295 additions and 31 deletions.
31 changes: 28 additions & 3 deletions dsl/common/src/main/kotlin/io/mockk/API.kt
Expand Up @@ -588,6 +588,29 @@ object MockKDsl {
implementation.constructorMockFactory.clearAll(options)
}
}

/*
* Checks if provided mock is mock of certain type
*/
fun internalIsMockKMock(
mock: Any,
regular: Boolean = true,
spy: Boolean = false,
objectMock: Boolean = false,
staticMock: Boolean = false,
constructorMock: Boolean = false
): Boolean {
val typeChecker = MockKGateway.implementation().mockTypeChecker

return when {
regular && typeChecker.isRegularMock(mock) -> true
spy && typeChecker.isSpy(mock) -> true
objectMock && typeChecker.isObjectMock(mock) -> true
staticMock && typeChecker.isStaticMock(mock) -> true
constructorMock && typeChecker.isConstructorMock(mock) -> true
else -> false
}
}
}

/**
Expand Down Expand Up @@ -651,9 +674,11 @@ open class MockKMatcherScope(

inline fun <reified T : Any> eq(value: T, inverse: Boolean = false): T =
match(EqMatcher(value, inverse = inverse))
inline fun <reified T: Any> neq(value: T): T = eq(value, true)

inline fun <reified T : Any> neq(value: T): T = eq(value, true)
inline fun <reified T : Any> refEq(value: T, inverse: Boolean = false): T =
match(EqMatcher(value, ref = true, inverse = inverse))

inline fun <reified T : Any> nrefEq(value: T) = refEq(value, true)

inline fun <reified T : Any> any(): T = match(ConstantMatcher(true))
Expand Down Expand Up @@ -2149,13 +2174,13 @@ class MockKAnswerScope<T, B>(
var fieldValueAny: Any?
set(value) {
val fv = backingFieldValue
?: throw MockKException("no backing field found for '${call.invocation.method.name}'")
?: throw MockKException("no backing field found for '${call.invocation.method.name}'")

fv.setter(value)
}
get() {
val fv = backingFieldValue
?: throw MockKException("no backing field found for '${call.invocation.method.name}'")
?: throw MockKException("no backing field found for '${call.invocation.method.name}'")
return fv.getter()
}
}
Expand Down
13 changes: 13 additions & 0 deletions dsl/common/src/main/kotlin/io/mockk/GatewayAPI.kt
Expand Up @@ -18,6 +18,7 @@ interface MockKGateway {
val clearer: Clearer
val mockInitializer: MockInitializer
val verificationAcknowledger: VerificationAcknowledger
val mockTypeChecker: MockTypeChecker

fun verifier(params: VerificationParameters): CallVerifier

Expand Down Expand Up @@ -247,4 +248,16 @@ interface MockKGateway {

fun acknowledgeVerified(mock: Any)
}

interface MockTypeChecker {
fun isRegularMock(mock: Any): Boolean

fun isSpy(mock: Any): Boolean

fun isObjectMock(mock: Any): Boolean

fun isStaticMock(mock: Any): Boolean

fun isConstructorMock(mock: Any): Boolean
}
}
22 changes: 22 additions & 0 deletions mockk/common/src/main/kotlin/io/mockk/MockK.kt
Expand Up @@ -650,6 +650,28 @@ inline fun clearAllMocks(
)
}

/**
* Checks if provided mock is mock of certain type
*/
fun isMockKMock(
mock: Any,
regular: Boolean = true,
spy: Boolean = false,
objectMock: Boolean = false,
staticMock: Boolean = false,
constructorMock: Boolean = false
) = MockK.useImpl {
MockKDsl.internalIsMockKMock(
mock,
regular,
spy,
objectMock,
staticMock,
constructorMock
)
}


object MockKAnnotations {
/**
* Initializes properties annotated with @MockK, @RelaxedMockK, @Slot and @SpyK in provided object.
Expand Down
Expand Up @@ -45,7 +45,8 @@ abstract class AbstractMockFactory(
relaxed || MockKSettings.relaxed,
relaxUnitFun || MockKSettings.relaxUnitFun,
gatewayAccess,
true
true,
MockType.REGULAR
)

if (moreInterfaces.isEmpty()) {
Expand Down Expand Up @@ -90,7 +91,8 @@ abstract class AbstractMockFactory(
actualCls,
newName,
gatewayAccess,
recordPrivateCalls || MockKSettings.recordPrivateCalls
recordPrivateCalls || MockKSettings.recordPrivateCalls,
MockType.SPY
)

val useDefaultConstructor = objToCopy == null
Expand All @@ -116,7 +118,8 @@ abstract class AbstractMockFactory(
mockType,
"temporary mock",
gatewayAccess = gatewayAccess,
recordPrivateCalls = true
recordPrivateCalls = true,
mockType = MockType.TEMPORARY
)

log.trace { "Building proxy for ${mockType.toStr()} hashcode=${InternalPlatform.hkd(mockType)}" }
Expand Down
@@ -1,15 +1,15 @@
package io.mockk.impl.recording.states

import io.mockk.Answer
import io.mockk.ConstantAnswer
import io.mockk.InternalPlatformDsl
import io.mockk.InvocationMatcher
import io.mockk.*
import io.mockk.InternalPlatformDsl.toStr
import io.mockk.impl.log.Logger
import io.mockk.impl.recording.CommonCallRecorder
import io.mockk.impl.stub.AdditionalAnswerOpportunity
import io.mockk.impl.stub.MockKStub

class StubbingAwaitingAnswerState(recorder: CommonCallRecorder) : CallRecordingState(recorder) {
val log = recorder.safeToString(Logger<StubbingAwaitingAnswerState>())

override fun answer(answer: Answer<*>) {
val calls = recorder.calls

Expand Down Expand Up @@ -43,6 +43,18 @@ class StubbingAwaitingAnswerState(recorder: CommonCallRecorder) : CallRecordingS
recorder.state = recorder.factories.answeringStillAcceptingAnswersState(recorder, answerOpportunity!!)
}

override fun call(invocation: Invocation): Any? {
val stub = recorder.stubRepo.stubFor(invocation.self)
try {
val answer = stub.answer(invocation)
log.debug { "Answering(await answering state) ${answer.toStr()} on $invocation" }
return answer
} catch (ex: Exception) {
log.debug { "Throwing(await answering state) ${ex.toStr()} on $invocation" }
throw ex
}
}

private fun assignFieldIfMockingProperty(mock: Any, matcher: InvocationMatcher, ans: Answer<Any?>) {
try {
if (ans !is ConstantAnswer) {
Expand All @@ -63,8 +75,4 @@ class StubbingAwaitingAnswerState(recorder: CommonCallRecorder) : CallRecordingS
}

private fun String.toCamelCase() = if (isEmpty()) this else substring(0, 1).toLowerCase() + substring(1)

companion object {
val log = Logger<StubbingAwaitingAnswerState>()
}
}
3 changes: 2 additions & 1 deletion mockk/common/src/main/kotlin/io/mockk/impl/stub/MockKStub.kt
Expand Up @@ -12,7 +12,8 @@ open class MockKStub(
val relaxed: Boolean = false,
val relaxUnitFun: Boolean = false,
val gatewayAccess: StubGatewayAccess,
val recordPrivateCalls: Boolean
val recordPrivateCalls: Boolean,
val mockType: MockType
) : Stub {
val log = gatewayAccess.safeToString(Logger<MockKStub>())

Expand Down
10 changes: 10 additions & 0 deletions mockk/common/src/main/kotlin/io/mockk/impl/stub/MockType.kt
@@ -0,0 +1,10 @@
package io.mockk.impl.stub

enum class MockType {
REGULAR,
SPY,
OBJECT,
STATIC,
CONSTRUCTOR,
TEMPORARY
}
5 changes: 3 additions & 2 deletions mockk/common/src/main/kotlin/io/mockk/impl/stub/SpyKStub.kt
Expand Up @@ -7,8 +7,9 @@ class SpyKStub<T : Any>(
cls: KClass<out T>,
name: String,
gatewayAccess: StubGatewayAccess,
recordPrivateCalls: Boolean
) : MockKStub(cls, name, false, false, gatewayAccess, recordPrivateCalls) {
recordPrivateCalls: Boolean,
mockType: MockType
) : MockKStub(cls, name, false, false, gatewayAccess, recordPrivateCalls, mockType) {

override fun defaultAnswer(invocation: Invocation): Any? {
return invocation.originalCall()
Expand Down
Expand Up @@ -2,17 +2,18 @@ package io.mockk.impl.stub

import io.mockk.MockKException
import io.mockk.impl.InternalPlatform
import io.mockk.impl.MultiNotifier
import io.mockk.impl.MultiNotifier.Session
import io.mockk.impl.WeakRef
import io.mockk.impl.log.SafeToString

class StubRepository(val safeToString: SafeToString) {
class StubRepository(
val safeToString: SafeToString
) {
private val stubs = InternalPlatform.weakMap<Any, WeakRef>()
private val recordCallMultiNotifier = InternalPlatform.multiNotifier()

fun stubFor(mock: Any): Stub = get(mock)
?: throw MockKException(safeToString.exec { "can't find stub $mock" })
?: throw MockKException(safeToString.exec { "can't find stub $mock" })

fun add(mock: Any, stub: Stub) {
stubs[mock] = InternalPlatform.weakRef(stub)
Expand Down
39 changes: 39 additions & 0 deletions mockk/common/src/test/kotlin/io/mockk/gh/Issue221Test.kt
@@ -0,0 +1,39 @@
package io.mockk.gh

import io.mockk.every
import io.mockk.isMockKMock
import io.mockk.mockk
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class Issue221Test {
interface Foo {
fun getTasks(): Set<Task>

fun getTask(): Task
}

interface Task {
fun getSubTask(): Task

fun doIt(): Int
}


@Test
fun returnsSetOfMocks() {
val foo = mockk<Foo>()
every { foo.getTasks() } returns setOf(mockk(), mockk())

val tasks = foo.getTasks()

assertEquals(2, tasks.size)
val task1 = tasks.first()
val task2 = tasks.drop(1).first()

assertTrue(isMockKMock(task1))
assertTrue(isMockKMock(task2))
}

}
68 changes: 68 additions & 0 deletions mockk/common/src/test/kotlin/io/mockk/it/MockTypesTest.kt
@@ -0,0 +1,68 @@
package io.mockk.it

import io.mockk.*
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class MockTypesTest {
enum class MockType {
REGULAR,
SPY,
OBJECT,
STATIC,
CONSTRUCTOR
}

class TestCls

@Test
fun regularMock() {
assertOnlyOfType(mockk(), MockType.REGULAR)
}

@Test
fun spyMock() {
assertOnlyOfType(spyk(), MockType.SPY)
}

@Test
fun objectMock() {
val test = Any()
mockkObject(test) {
assertOnlyOfType(test, MockType.OBJECT)
}
}

@Test
fun staticMock() {
mockkStatic(TestCls::class) {
assertOnlyOfType(TestCls::class, MockType.STATIC)
}
}

@Test
fun constructorMock() {
mockkConstructor(TestCls::class) {
assertOnlyOfType(TestCls::class, MockType.CONSTRUCTOR)
}
}

fun assertOnlyOfType(mock: Any, singleType: MockType) {
for (type in MockType.values()) {
if (singleType == type) {
assertTrue(isOfMockType(mock, type), "mock is not of type $singleType")
} else {
assertFalse(isOfMockType(mock, type), "mock should be of type $singleType, but it is as well $type")
}
}
}

fun isOfMockType(mock: Any, type: MockType) = when (type) {
MockType.REGULAR -> isMockKMock(mock)
MockType.SPY -> isMockKMock(mock, regular = false, spy = true)
MockType.OBJECT -> isMockKMock(mock, regular = false, objectMock = true)
MockType.STATIC -> isMockKMock(mock, regular = false, staticMock = true)
MockType.CONSTRUCTOR -> isMockKMock(mock, regular = false, constructorMock = true)
}
}
4 changes: 4 additions & 0 deletions mockk/jvm/src/main/kotlin/io/mockk/impl/JvmMockKGateway.kt
Expand Up @@ -92,6 +92,10 @@ class JvmMockKGateway : MockKGateway {
gatewayAccessWithFactory
)

override val mockTypeChecker = JvmMockTypeChecker(
stubRepo,
constructorMockFactory::isMock
)

override fun verifier(params: VerificationParameters): CallVerifier {
val ordering = params.ordering
Expand Down

0 comments on commit 5efe6b6

Please sign in to comment.