Skip to content

Commit

Permalink
Add name API (#158)
Browse files Browse the repository at this point in the history
Add name API

Co-authored-by: Jake Wharton <jw@squareup.com>
  • Loading branch information
jingibus and JakeWharton committed Oct 11, 2022
1 parent 6bb1e72 commit b9ccd28
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 45 deletions.
32 changes: 19 additions & 13 deletions src/commonMain/kotlin/app/cash/turbine/Turbine.kt
Expand Up @@ -91,14 +91,19 @@ public operator fun <T> Turbine<T>.plusAssign(value: T) { add(value) }
*
* @param timeout If non-null, overrides the current Turbine timeout for this [Turbine]. See also:
* [withTurbineTimeout].
* @param name If non-null, name is added to any exceptions thrown to help identify which [Turbine] failed.
*/
@Suppress("FunctionName") // Interface constructor pattern.
public fun <T> Turbine(timeout: Duration? = null): Turbine<T> = ChannelTurbine(timeout = timeout)
public fun <T> Turbine(
timeout: Duration? = null,
name: String? = null,
): Turbine<T> = ChannelTurbine(timeout = timeout, name = name)

internal class ChannelTurbine<T>(
channel: Channel<T> = Channel(UNLIMITED),
private val job: Job? = null,
private val timeout: Duration?,
private val name: String?,
) : Turbine<T> {
private suspend fun <T> withTurbineTimeout(block: suspend () -> T): T {
return if (timeout != null) {
Expand Down Expand Up @@ -145,13 +150,13 @@ internal class ChannelTurbine<T>(
job?.cancel()
}

override fun takeEvent(): Event<T> = channel.takeEvent()
override fun takeEvent(): Event<T> = channel.takeEvent(name = name)

override fun takeItem(): T = channel.takeItem()
override fun takeItem(): T = channel.takeItem(name = name)

override fun takeComplete() = channel.takeComplete()
override fun takeComplete() = channel.takeComplete(name = name)

override fun takeError(): Throwable = channel.takeError()
override fun takeError(): Throwable = channel.takeError(name = name)

private var ignoreTerminalEvents = false
private var ignoreRemainingEvents = false
Expand All @@ -176,20 +181,20 @@ internal class ChannelTurbine<T>(
}

override fun expectNoEvents() {
channel.expectNoEvents()
channel.expectNoEvents(name = name)
}

override fun expectMostRecentItem(): T = channel.expectMostRecentItem()
override fun expectMostRecentItem(): T = channel.expectMostRecentItem(name = name)

override suspend fun awaitEvent(): Event<T> = withTurbineTimeout { channel.awaitEvent() }
override suspend fun awaitEvent(): Event<T> = withTurbineTimeout { channel.awaitEvent(name = name) }

override suspend fun awaitItem(): T = withTurbineTimeout { channel.awaitItem() }
override suspend fun awaitItem(): T = withTurbineTimeout { channel.awaitItem(name = name) }

override suspend fun skipItems(count: Int) = withTurbineTimeout { channel.skipItems(count) }
override suspend fun skipItems(count: Int) = withTurbineTimeout { channel.skipItems(count, name) }

override suspend fun awaitComplete() = withTurbineTimeout { channel.awaitComplete() }
override suspend fun awaitComplete() = withTurbineTimeout { channel.awaitComplete(name = name) }

override suspend fun awaitError(): Throwable = withTurbineTimeout { channel.awaitError() }
override suspend fun awaitError(): Throwable = withTurbineTimeout { channel.awaitError(name = name) }

override fun ensureAllEventsConsumed() {
if (ignoreRemainingEvents) return
Expand All @@ -209,7 +214,8 @@ internal class ChannelTurbine<T>(
if (unconsumed.isNotEmpty()) {
throw TurbineAssertionError(
buildString {
append("Unconsumed events found:")
append("Unconsumed events found".qualifiedBy(name))
append(":")
for (event in unconsumed) {
append("\n - $event")
}
Expand Down
62 changes: 35 additions & 27 deletions src/commonMain/kotlin/app/cash/turbine/channel.kt
Expand Up @@ -42,7 +42,7 @@ import kotlinx.coroutines.withTimeout
*
* @throws AssertionError if no item was emitted.
*/
public fun <T> ReceiveChannel<T>.expectMostRecentItem(): T {
public fun <T> ReceiveChannel<T>.expectMostRecentItem(name: String? = null): T {
var previous: ChannelResult<T>? = null
while (true) {
val current = tryReceive()
Expand All @@ -55,7 +55,7 @@ public fun <T> ReceiveChannel<T>.expectMostRecentItem(): T {

if (previous?.isSuccess == true) return previous.getOrThrow()

throw AssertionError("No item was found")
throw AssertionError("No item was found".qualifiedBy(name))
}

/**
Expand All @@ -66,9 +66,9 @@ public fun <T> ReceiveChannel<T>.expectMostRecentItem(): T {
*
* @throws AssertionError if unconsumed events are found.
*/
public fun <T> ReceiveChannel<T>.expectNoEvents() {
public fun <T> ReceiveChannel<T>.expectNoEvents(name: String? = null) {
val result = tryReceive()
if (!result.isFailure) result.unexpectedResult("no events")
if (!result.isFailure) result.unexpectedResult(name, "no events")
}

/**
Expand All @@ -77,17 +77,17 @@ public fun <T> ReceiveChannel<T>.expectNoEvents() {
*
* This function will always return a terminal event on a closed [ReceiveChannel].
*/
public suspend fun <T> ReceiveChannel<T>.awaitEvent(): Event<T> {
public suspend fun <T> ReceiveChannel<T>.awaitEvent(name: String? = null): Event<T> {
val timeout = contextTimeout()
return try {
withAppropriateTimeout(timeout) {
val item = receive()
Event.Item(item)
}
} catch (e: TimeoutCancellationException) {
throw AssertionError("No value produced in $timeout")
throw AssertionError("No ${"value produced".qualifiedBy(name)} in $timeout")
} catch (e: TurbineTimeoutCancellationException) {
throw AssertionError("No value produced in $timeout")
throw AssertionError("No ${"value produced".qualifiedBy(name)} in $timeout")
} catch (e: CancellationException) {
throw e
} catch (e: ClosedReceiveChannelException) {
Expand Down Expand Up @@ -139,10 +139,10 @@ internal class TurbineTimeoutCancellationException internal constructor(
*
* @throws AssertionError if the next event was completion or an error.
*/
public fun <T> ReceiveChannel<T>.takeEvent(): Event<T> {
public fun <T> ReceiveChannel<T>.takeEvent(name: String? = null): Event<T> {
assertCallingContextIsNotSuspended()
return takeEventUnsafe()
?: unexpectedEvent(null, "an event")
?: unexpectedEvent(name, null, "an event")
}

internal fun <T> ReceiveChannel<T>.takeEventUnsafe(): Event<T>? {
Expand All @@ -155,9 +155,9 @@ internal fun <T> ReceiveChannel<T>.takeEventUnsafe(): Event<T>? {
*
* @throws AssertionError if the next event was completion or an error, or no event.
*/
public fun <T> ReceiveChannel<T>.takeItem(): T {
public fun <T> ReceiveChannel<T>.takeItem(name: String? = null): T {
val event = takeEvent()
return (event as? Event.Item)?.value ?: unexpectedEvent(event, "item")
return (event as? Event.Item)?.value ?: unexpectedEvent(name, event, "item")
}

/**
Expand All @@ -166,9 +166,9 @@ public fun <T> ReceiveChannel<T>.takeItem(): T {
*
* @throws AssertionError if the next event was completion or an error.
*/
public fun <T> ReceiveChannel<T>.takeComplete() {
public fun <T> ReceiveChannel<T>.takeComplete(name: String? = null) {
val event = takeEvent()
if (event !is Event.Complete) unexpectedEvent(event, "complete")
if (event !is Event.Complete) unexpectedEvent(name, event, "complete")
}

/**
Expand All @@ -177,9 +177,9 @@ public fun <T> ReceiveChannel<T>.takeComplete() {
*
* @throws AssertionError if the next event was completion or an error.
*/
public fun <T> ReceiveChannel<T>.takeError(): Throwable {
public fun <T> ReceiveChannel<T>.takeError(name: String? = null): Throwable {
val event = takeEvent()
return (event as? Event.Error)?.throwable ?: unexpectedEvent(event, "error")
return (event as? Event.Error)?.throwable ?: unexpectedEvent(name, event, "error")
}

/**
Expand All @@ -188,10 +188,10 @@ public fun <T> ReceiveChannel<T>.takeError(): Throwable {
*
* @throws AssertionError if the next event was completion or an error.
*/
public suspend fun <T> ReceiveChannel<T>.awaitItem(): T =
when (val result = awaitEvent()) {
public suspend fun <T> ReceiveChannel<T>.awaitItem(name: String? = null): T =
when (val result = awaitEvent(name = name)) {
is Event.Item -> result.value
else -> unexpectedEvent(result, "item")
else -> unexpectedEvent(name, result, "item")
}

/**
Expand All @@ -200,12 +200,12 @@ public suspend fun <T> ReceiveChannel<T>.awaitItem(): T =
*
* @throws AssertionError if one of the events was completion or an error.
*/
public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int) {
public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int, name: String? = null) {
repeat(count) { index ->
when (val event = awaitEvent()) {
Event.Complete, is Event.Error -> {
val cause = (event as? Event.Error)?.throwable
throw TurbineAssertionError("Expected $count items but got $index items and $event", cause)
throw TurbineAssertionError("Expected $count ${"items".qualifiedBy(name)} but got $index items and $event", cause)
}
is Event.Item<T> -> {
// Success
Expand All @@ -220,10 +220,10 @@ public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int) {
*
* @throws AssertionError if the next event was an item or an error.
*/
public suspend fun <T> ReceiveChannel<T>.awaitComplete() {
public suspend fun <T> ReceiveChannel<T>.awaitComplete(name: String? = null) {
val event = awaitEvent()
if (event != Event.Complete) {
unexpectedEvent(event, "complete")
unexpectedEvent(name, event, "complete")
}
}

Expand All @@ -233,10 +233,10 @@ public suspend fun <T> ReceiveChannel<T>.awaitComplete() {
*
* @throws AssertionError if the next event was an item or completion.
*/
public suspend fun <T> ReceiveChannel<T>.awaitError(): Throwable {
public suspend fun <T> ReceiveChannel<T>.awaitError(name: String? = null): Throwable {
val event = awaitEvent()
return (event as? Event.Error)?.throwable
?: unexpectedEvent(event, "error")
?: unexpectedEvent(name, event, "error")
}

internal fun <T> ChannelResult<T>.toEvent(): Event<T>? {
Expand All @@ -249,10 +249,18 @@ internal fun <T> ChannelResult<T>.toEvent(): Event<T>? {
}
}

private fun <T> ChannelResult<T>.unexpectedResult(expected: String): Nothing = unexpectedEvent(toEvent(), expected)
private fun <T> ChannelResult<T>.unexpectedResult(name: String?, expected: String): Nothing =
unexpectedEvent(name, toEvent(), expected)

private fun unexpectedEvent(event: Event<*>?, expected: String): Nothing {
private fun unexpectedEvent(name: String?, event: Event<*>?, expected: String): Nothing {
val cause = (event as? Event.Error)?.throwable
val eventAsString = event?.toString() ?: "no items"
throw TurbineAssertionError("Expected $expected but found $eventAsString", cause)
throw TurbineAssertionError("Expected ${expected.qualifiedBy(name)} but found $eventAsString", cause)
}

internal fun String.qualifiedBy(name: String?) =
if (name == null) {
this
} else {
"$this for $name"
}
15 changes: 10 additions & 5 deletions src/commonMain/kotlin/app/cash/turbine/flow.kt
Expand Up @@ -48,10 +48,11 @@ import kotlinx.coroutines.test.UnconfinedTestDispatcher
*/
public suspend fun <T> Flow<T>.test(
timeout: Duration? = null,
name: String? = null,
validate: suspend ReceiveTurbine<T>.() -> Unit,
) {
coroutineScope {
collectTurbineIn(this, null).apply {
collectTurbineIn(this, null, name).apply {
if (timeout != null) {
withTurbineTimeout(timeout) {
validate()
Expand Down Expand Up @@ -83,13 +84,17 @@ public suspend fun <T> Flow<T>.test(
* @param timeout If non-null, overrides the current Turbine timeout for this [Turbine]. See also:
* [withTurbineTimeout].
*/
public fun <T> Flow<T>.testIn(scope: CoroutineScope, timeout: Duration? = null): ReceiveTurbine<T> {
public fun <T> Flow<T>.testIn(
scope: CoroutineScope,
timeout: Duration? = null,
name: String? = null,
): ReceiveTurbine<T> {
if (timeout != null) {
// Eager check to throw early rather than in a subsequent 'await' call.
checkTimeout(timeout)
}

val turbine = collectTurbineIn(scope, timeout)
val turbine = collectTurbineIn(scope, timeout, name)

scope.coroutineContext.job.invokeOnCompletion { exception ->
if (debug) println("Scope ending ${exception ?: ""}")
Expand All @@ -104,7 +109,7 @@ public fun <T> Flow<T>.testIn(scope: CoroutineScope, timeout: Duration? = null):
}

@OptIn(ExperimentalCoroutinesApi::class) // New kotlinx.coroutines test APIs are not stable 😬
private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeout: Duration?): Turbine<T> {
private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeout: Duration?, name: String?): Turbine<T> {
lateinit var channel: Channel<T>

// Use test-specific unconfined if test scheduler is in use to inherit its virtual time.
Expand All @@ -116,7 +121,7 @@ private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeout: Duratio
channel = collectIntoChannel(this)
}

return ChannelTurbine(channel, job, timeout)
return ChannelTurbine(channel, job, timeout, name)
}

internal fun <T> Flow<T>.collectIntoChannel(scope: CoroutineScope): Channel<T> {
Expand Down
63 changes: 63 additions & 0 deletions src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt
Expand Up @@ -108,6 +108,14 @@ class ChannelTest {
assertEquals(3, channel.awaitItem())
}

@Test fun skipItemsThrowsOnComplete() = runTest {
val channel = flowOf(1, 2).collectIntoChannel(this)
val message = assertFailsWith<AssertionError> {
channel.skipItems(3)
}.message
assertEquals("Expected 3 items but got 2 items and Complete", message)
}

@Test fun expectErrorOnCompletionBeforeAllItemsWereSkipped() = runTest {
val channel = flowOf(1).collectIntoChannel(this)
assertFailsWith<AssertionError> {
Expand Down Expand Up @@ -286,6 +294,61 @@ class ChannelTest {
assertSame(error, actual.cause)
}

@Test
fun expectMostRecentItemButNoItemWasFoundThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
val channel = emptyFlow<Any>().collectIntoChannel(this)
channel.expectMostRecentItem(name = "empty flow")
}
assertEquals("No item was found for empty flow", actual.message)
}

@Test fun awaitItemButWasCloseThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
emptyFlow<Unit>().collectIntoChannel(this).awaitItem(name = "closed flow")
}
assertEquals("Expected item for closed flow but found Complete", actual.message)
}

@Test fun awaitCompleteButWasItemThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
flowOf("item!").collectIntoChannel(this)
.awaitComplete(name = "item flow")
}
assertEquals("Expected complete for item flow but found Item(item!)", actual.message)
}

@Test fun awaitErrorButWasItemThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
flowOf("item!").collectIntoChannel(this).awaitError(name = "item flow")
}
assertEquals("Expected error for item flow but found Item(item!)", actual.message)
}

@Test fun awaitHonorsCoroutineContextTimeoutTimeoutWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
withTurbineTimeout(10.milliseconds) {
neverFlow().collectIntoChannel(this).awaitItem(name = "never flow")
}
}
assertEquals("No value produced for never flow in 10ms", actual.message)
}

@Test fun takeItemButWasCloseThrowsWithName() = withTestScope {
val actual = assertFailsWith<AssertionError> {
emptyFlow<Unit>().collectIntoChannel(this).takeItem(name = "empty flow")
}
assertEquals("Expected item for empty flow but found Complete", actual.message)
}

@Test fun skipItemsThrowsOnCompleteWithName() = runTest {
val channel = flowOf(1, 2).collectIntoChannel(this)
val message = assertFailsWith<AssertionError> {
channel.skipItems(3, name = "two item channel")
}.message
assertEquals("Expected 3 items for two item channel but got 2 items and Complete", message)
}

/**
* Used to run test code with a [TestScope], but still outside a suspending context.
*/
Expand Down

0 comments on commit b9ccd28

Please sign in to comment.