Skip to content

Commit

Permalink
baton takeif
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Jan 24, 2020
1 parent 5524210 commit 159f031
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
12 changes: 12 additions & 0 deletions src/commonMain/kotlin/com.koushikdutta.scratch/atomic/freezable.kt
Expand Up @@ -67,6 +67,18 @@ class FreezableReference<T>: Freezable {
return setInternal(value, false)
}

fun compareAndSet(existing: FreezableValue<T>, value: T, freeze: Boolean = false): Boolean {
if (existing.frozen == true)
return false
return atomicReference.compareAndSet(existing, FreezableValue(freeze, value))
}

fun compareAndSetNull(existing: FreezableValue<T>): Boolean {
if (existing.frozen == true)
return false
return atomicReference.compareAndSet(existing, null)
}

private fun setInternal(value: T, freeze: Boolean): FreezableValue<T>? {
val newValue = FreezableValue(freeze, value)

Expand Down
46 changes: 37 additions & 9 deletions src/commonMain/kotlin/com.koushikdutta.scratch/baton.kt
Expand Up @@ -24,6 +24,7 @@ private fun <T> Continuation<T>.resume(result: LockResult<T>) {
class BatonResult<T>(throwable: Throwable?, value: T?, val resumed: Boolean, val finished: Boolean): LockResult<T>(throwable, value)
typealias BatonLock<T, R> = (result: BatonResult<T>) -> R
typealias BatonTossLock<T, R> = (result: BatonResult<T>?) -> R
typealias BatonTakeCondition<T> = (result: BatonResult<T>) -> Boolean

private fun <T, R> BatonTossLock<T, R>.resultInvoke(result: BatonResult<T>?): LockResult<R> {
return try {
Expand Down Expand Up @@ -83,10 +84,10 @@ private data class BatonWaiter<T, R>(val continuation: Continuation<R>?, val dat
}
}

class Baton<T>() {
class Baton<T> {
private val freeze = FreezableReference<BatonWaiter<T, *>>()

private fun <R> passInternal(throwable: Throwable?, value: T?, lock: BatonTossLock<T, R>? = null, take: Boolean = false, finish: Boolean = false, continuation: Continuation<R>? = null): R? {
private fun <R> passInternal(throwable: Throwable?, value: T?, lock: BatonTossLock<T, R>? = null, take: BatonTakeCondition<T>? = null, finish: Boolean = false, continuation: Continuation<R>? = null): R? {
val immediate = continuation == null
val dataLock = if (immediate) null else lock
val cdata = if (finish) {
Expand All @@ -97,15 +98,33 @@ class Baton<T>() {
else
BatonContinuationLockedData(null, null, false)
}
else if (take != null) {
val taken: BatonContinuationLockedData<T, out Any?>
while (true) {
val found = freeze.get()
if (found == null) {
taken = BatonContinuationLockedData(null, null, false)
break
}
if (take(BatonResult(found.value.data.throwable, found.value.data.value, true, found.frozen))) {
if (freeze.compareAndSetNull(found)) {
taken = found.value.getContinuationLockedData()
break
}
}
else {
taken = BatonContinuationLockedData(null, null, false)
break
}
}
taken
}
else {
val resume = freeze.nullSwap()
if (resume != null) {
// fast path with no extra allocations in case there's a waiter
resume.value.getContinuationLockedData()
}
else if (take) {
BatonContinuationLockedData(null, null, false)
}
else {
// slow path with allocations and spin lock
val waiter = freeze.swapIfNullElseNull(BatonWaiter(continuation, BatonData(throwable, value, dataLock), false))
Expand Down Expand Up @@ -147,20 +166,29 @@ class Baton<T>() {
it?.value
}

fun takeIf(value: T, takeCondition: BatonTakeCondition<T>): T? {
return passInternal(null, value, take = takeCondition, lock = defaultTossLock)
}

fun <R> takeIf(value: T, takeCondition: BatonTakeCondition<T>, tossLock: BatonTossLock<T, R>): R {
return passInternal(null, value, take = takeCondition, lock = tossLock)!!
}

private val defaultTakeCondition: BatonTakeCondition<T> = { true }
fun take(value: T): T? {
return passInternal(null, value, take = true, lock = defaultTossLock)
return passInternal(null, value, take = defaultTakeCondition, lock = defaultTossLock)
}

fun takeRaise(throwable: Throwable): T? {
return passInternal(throwable, null, take = true, lock = defaultTossLock)
return passInternal(throwable, null, take = defaultTakeCondition, lock = defaultTossLock)
}

fun <R> take(value: T, tossLock: BatonTossLock<T, R>): R {
return passInternal(null, value, take = true, lock = tossLock)!!
return passInternal(null, value, take = defaultTakeCondition, lock = tossLock)!!
}

fun <R> takeRaise(throwable: Throwable, tossLock: BatonTossLock<T, R>): R {
return passInternal(throwable, null, take = true, lock = tossLock)!!
return passInternal(throwable, null, take = defaultTakeCondition, lock = tossLock)!!
}

fun toss(value: T): T? {
Expand Down
13 changes: 11 additions & 2 deletions src/commonTest/kotlin/com/koushikdutta/scratch/BatonTests.kt
Expand Up @@ -249,9 +249,18 @@ class BatonTests {
@Test
fun testBatonTake() {
val baton = Baton<Int>()
assertNull(baton.take(2), null)
assertEquals(baton.take(2), null)
assertNull(baton.toss(3))
assertEquals(baton.take(4), 3)
assertNull(baton.take(5), null)
assertEquals(baton.take(5), null)
}

@Test
fun testBatonTakeIf() {
val baton = Baton<Int>()
assertEquals(baton.take(2), null)
assertNull(baton.toss(3))
assertNull(baton.takeIf(4) { it.value == 5 })
assertEquals(baton.take(5), 3)
}
}

0 comments on commit 159f031

Please sign in to comment.