diff --git a/src/commonMain/kotlin/com.koushikdutta.scratch/atomic/freezable.kt b/src/commonMain/kotlin/com.koushikdutta.scratch/atomic/freezable.kt index e45e2a2..d114e88 100644 --- a/src/commonMain/kotlin/com.koushikdutta.scratch/atomic/freezable.kt +++ b/src/commonMain/kotlin/com.koushikdutta.scratch/atomic/freezable.kt @@ -67,6 +67,18 @@ class FreezableReference: Freezable { return setInternal(value, false) } + fun compareAndSet(existing: FreezableValue, value: T, freeze: Boolean = false): Boolean { + if (existing.frozen == true) + return false + return atomicReference.compareAndSet(existing, FreezableValue(freeze, value)) + } + + fun compareAndSetNull(existing: FreezableValue): Boolean { + if (existing.frozen == true) + return false + return atomicReference.compareAndSet(existing, null) + } + private fun setInternal(value: T, freeze: Boolean): FreezableValue? { val newValue = FreezableValue(freeze, value) diff --git a/src/commonMain/kotlin/com.koushikdutta.scratch/baton.kt b/src/commonMain/kotlin/com.koushikdutta.scratch/baton.kt index 4cba463..c4b8964 100644 --- a/src/commonMain/kotlin/com.koushikdutta.scratch/baton.kt +++ b/src/commonMain/kotlin/com.koushikdutta.scratch/baton.kt @@ -24,6 +24,7 @@ private fun Continuation.resume(result: LockResult) { class BatonResult(throwable: Throwable?, value: T?, val resumed: Boolean, val finished: Boolean): LockResult(throwable, value) typealias BatonLock = (result: BatonResult) -> R typealias BatonTossLock = (result: BatonResult?) -> R +typealias BatonTakeCondition = (result: BatonResult) -> Boolean private fun BatonTossLock.resultInvoke(result: BatonResult?): LockResult { return try { @@ -83,10 +84,10 @@ private data class BatonWaiter(val continuation: Continuation?, val dat } } -class Baton() { +class Baton { private val freeze = FreezableReference>() - private fun passInternal(throwable: Throwable?, value: T?, lock: BatonTossLock? = null, take: Boolean = false, finish: Boolean = false, continuation: Continuation? = null): R? { + private fun passInternal(throwable: Throwable?, value: T?, lock: BatonTossLock? = null, take: BatonTakeCondition? = null, finish: Boolean = false, continuation: Continuation? = null): R? { val immediate = continuation == null val dataLock = if (immediate) null else lock val cdata = if (finish) { @@ -97,15 +98,33 @@ class Baton() { else BatonContinuationLockedData(null, null, false) } + else if (take != null) { + val taken: BatonContinuationLockedData + while (true) { + val found = freeze.get() + if (found == null) { + taken = BatonContinuationLockedData(null, null, false) + break + } + if (take(BatonResult(found.value.data.throwable, found.value.data.value, true, found.frozen))) { + if (freeze.compareAndSetNull(found)) { + taken = found.value.getContinuationLockedData() + break + } + } + else { + taken = BatonContinuationLockedData(null, null, false) + break + } + } + taken + } else { val resume = freeze.nullSwap() if (resume != null) { // fast path with no extra allocations in case there's a waiter resume.value.getContinuationLockedData() } - else if (take) { - BatonContinuationLockedData(null, null, false) - } else { // slow path with allocations and spin lock val waiter = freeze.swapIfNullElseNull(BatonWaiter(continuation, BatonData(throwable, value, dataLock), false)) @@ -147,20 +166,29 @@ class Baton() { it?.value } + fun takeIf(value: T, takeCondition: BatonTakeCondition): T? { + return passInternal(null, value, take = takeCondition, lock = defaultTossLock) + } + + fun takeIf(value: T, takeCondition: BatonTakeCondition, tossLock: BatonTossLock): R { + return passInternal(null, value, take = takeCondition, lock = tossLock)!! + } + + private val defaultTakeCondition: BatonTakeCondition = { true } fun take(value: T): T? { - return passInternal(null, value, take = true, lock = defaultTossLock) + return passInternal(null, value, take = defaultTakeCondition, lock = defaultTossLock) } fun takeRaise(throwable: Throwable): T? { - return passInternal(throwable, null, take = true, lock = defaultTossLock) + return passInternal(throwable, null, take = defaultTakeCondition, lock = defaultTossLock) } fun take(value: T, tossLock: BatonTossLock): R { - return passInternal(null, value, take = true, lock = tossLock)!! + return passInternal(null, value, take = defaultTakeCondition, lock = tossLock)!! } fun takeRaise(throwable: Throwable, tossLock: BatonTossLock): R { - return passInternal(throwable, null, take = true, lock = tossLock)!! + return passInternal(throwable, null, take = defaultTakeCondition, lock = tossLock)!! } fun toss(value: T): T? { diff --git a/src/commonTest/kotlin/com/koushikdutta/scratch/BatonTests.kt b/src/commonTest/kotlin/com/koushikdutta/scratch/BatonTests.kt index 89a74f5..9b6f8cc 100644 --- a/src/commonTest/kotlin/com/koushikdutta/scratch/BatonTests.kt +++ b/src/commonTest/kotlin/com/koushikdutta/scratch/BatonTests.kt @@ -249,9 +249,18 @@ class BatonTests { @Test fun testBatonTake() { val baton = Baton() - assertNull(baton.take(2), null) + assertEquals(baton.take(2), null) assertNull(baton.toss(3)) assertEquals(baton.take(4), 3) - assertNull(baton.take(5), null) + assertEquals(baton.take(5), null) + } + + @Test + fun testBatonTakeIf() { + val baton = Baton() + assertEquals(baton.take(2), null) + assertNull(baton.toss(3)) + assertNull(baton.takeIf(4) { it.value == 5 }) + assertEquals(baton.take(5), 3) } } \ No newline at end of file